def pdf(self, x):
    """Probability density at the Cartesian point ``x``.

    The point is mapped to polar coordinates ``(r, theta)`` with
    ``self.to_polar``. The radius is scored under a normal distribution
    centred on ``self.radius`` with variance ``self.var``; the angle gets
    the uniform density ``1 / (2*pi)``. The polar density is converted back
    to a Cartesian density by dividing by ``|det J|``, where ``J`` is the
    Jacobian of ``self.to_cartesian`` at the polar point.
    """
    r_theta = self.to_polar(x)
    radius, _ = r_theta
    cart_jacobian = jacfwd(self.to_cartesian)(r_theta)
    volume_factor = np.abs(np.linalg.det(cart_jacobian))
    radial_density = stats.norm.pdf(radius, loc=self.radius, scale=np.sqrt(self.var))
    return radial_density / (2 * np.pi) / volume_factor
def __init__(self):
    """Configure a torque-limited pendulum task (dynamics, derivatives, costs).

    The dynamics below read the state as ``[sin(theta), cos(theta), theta_dot]``
    and the action as a single raw torque. The raw torque is squashed with
    ``tanh`` into ``[min_bounds, max_bounds]``. ``self.augment_state`` (defined
    elsewhere on the class) builds ``x_goal`` and ``initial_state`` from
    ``[theta, theta_dot]``.
    """
    self.min_bounds = -1.0
    self.max_bounds = 1.0
    self.m = 1.0
    self.l = 1.0
    self.g = 9.80665
    self.state_size = 3
    self.action_size = 1
    self.pendulum_length = 1.0
    self.x_goal = self.augment_state(np.array([0.0, 0.0]))
    self.initial_state = self.augment_state(np.array([np.pi, 0.0]))
    self.h = 300
    self.dt = 0.02
    self.viewer = None

    def _step(input_val, torque_gain):
        """One explicit-Euler step; ``torque_gain`` multiplies the applied torque."""
        diff = (self.max_bounds - self.min_bounds) / 2.0
        mean = (self.max_bounds + self.min_bounds) / 2.0
        x, u = input_val
        # Squash the raw action into [min_bounds, max_bounds].
        u = diff * np.tanh(u) + mean
        sin_theta = x[0]
        cos_theta = x[1]
        theta_dot = x[2]
        torque = u[0]
        theta = np.arctan2(sin_theta, cos_theta)
        theta_dot_dot = -3.0 * self.g / (2 * self.l) * np.sin(theta + np.pi)
        theta_dot_dot += torque_gain * torque
        next_theta = theta + theta_dot * self.dt
        return np.array([np.sin(next_theta), np.cos(next_theta),
                         theta_dot + theta_dot_dot * self.dt])

    @jax.jit
    def _dynamics(input_val):
        return _step(input_val, 3.0 / (self.m * self.l**2))

    def _dynamics_real(input_val):
        # NOTE(review): the "real" system uses m = l = 1.1 in the torque term only
        # (gravity still uses self.l) -- presumably a deliberate model mismatch; confirm.
        return _step(input_val, 3.0 / (1.1 * 1.1**2))

    self._dynamics = _dynamics
    self._dynamics_real = _dynamics_real
    self._dynamics_der = jax.jit(jax.jacfwd(_dynamics))
    self._dynamics_real_der = jax.jit(jax.jacfwd(_dynamics_real))

    # Identity running-state weight with no penalty on theta_dot (entry 2).
    # Built directly: jax.ops.index_update no longer exists in JAX.
    q_diag = [1.0] * self.state_size
    q_diag[2] = 0.0
    self.Q = np.diag(np.array(q_diag))
    self.Q_terminal = 100 * np.eye(self.state_size)
    self.R = np.array([[0.1]])

    def _costval(x, u, i):
        """Quadratic cost; step ``i == self.h`` uses the terminal weight only."""
        dx = x - self.x_goal
        if i == self.h:
            return dx.T @ self.Q_terminal @ dx
        return dx.T @ self.Q @ dx + u.T @ self.R @ u

    def _costgrad(x, u, i):
        """Derivatives of ``_costval``: [d/dx, d/du, d2/dx2, d2/dudx, d2/du2]."""
        zero_ux = np.zeros((self.action_size, self.state_size))
        if i == self.h:
            return [2 * self.Q_terminal @ (x - self.x_goal),
                    np.zeros((self.action_size,)),
                    2 * self.Q_terminal,
                    zero_ux,
                    np.zeros((self.action_size, self.action_size))]
        return [2 * self.Q @ (x - self.x_goal),
                2 * self.R @ u,
                2 * self.Q,
                zero_ux,
                2 * self.R]

    self._cost = _costval
    self._costgrad = _costgrad
def __init__(self, placebo=0, autodiff=True, method="3-point"):
    """Planar quadrotor task with optional dummy ("placebo") action channels.

    Args:
        placebo: number of extra action entries appended after the two thrusts.
            The dynamics read only ``u[:2]``; the cost still penalizes every entry
            against ``goal_action`` (zeros for the extra entries).
        autodiff: if True, all derivatives come from JAX. Otherwise the dynamics
            Jacobians use ``approx_derivative`` and the cost derivatives use the
            hand-written closed forms below.
        method: finite-difference scheme forwarded to ``approx_derivative``.
    """
    self.nsamples = 0
    self.m = 0.1
    self.l = 0.2
    self.g = 9.81
    self.dt = 0.05
    self.H = 100

    # Each of the two thrusts set to m*g/2 is the goal action.
    half_weight = self.m * self.g / 2.0
    self.initial_state = jnp.array([1.0, 1.0, 0.0, 0.0, 0.0, 0.0])
    self.goal_state = jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
    self.goal_action = jnp.hstack(
        (jnp.array([half_weight, half_weight]), jnp.zeros(placebo))
    )
    self.viewer = None
    self.state_dim = 6
    self.action_dim = 2 + placebo
    self.last_u = jnp.zeros((2,))

    def f(x, u):
        # Explicit-Euler step of [x, y, th, xdot, ydot, thdot]; counts every call.
        self.nsamples += 1
        _, _, th, xdot, ydot, thdot = x
        u1, u2 = u[:2]
        mass, grav, arm, step = self.m, self.g, self.l, self.dt
        accel_x = -(u1 + u2) * jnp.sin(th) / mass
        accel_y = (u1 + u2) * jnp.cos(th) / mass - grav
        accel_th = arm * (u2 - u1) / (mass * arm**2)
        rates = jnp.array([xdot, ydot, thdot, accel_x, accel_y, accel_th])
        return x + rates * step

    def c(x, u):
        du = u - self.goal_action
        dx = x - self.goal_state
        return 0.1 * du @ du + dx @ dx

    def f_x(x, u):
        return approx_derivative(lambda x_: f(x_, u), x, method=method)

    def f_u(x, u):
        return approx_derivative(lambda u_: f(x, u_), u, method=method)

    def c_x(x, u):
        return 2 * (x - self.goal_state)

    def c_u(x, u):
        return 2 * 0.1 * (u - self.goal_action)

    def c_xx(x, u):
        return 2 * jnp.eye(self.state_dim)

    def c_uu(x, u):
        return 2 * 0.1 * jnp.eye(self.action_dim)

    if not autodiff:
        self.f, self.f_x, self.f_u = f, f_x, f_u
        self.c, self.c_x, self.c_u, self.c_xx, self.c_uu = c, c_x, c_u, c_xx, c_uu
        return

    self.f = f
    self.f_x = jax.jacfwd(f, argnums=0)
    self.f_u = jax.jacfwd(f, argnums=1)
    self.c = c
    self.c_x = jax.grad(c, argnums=0)
    self.c_u = jax.grad(c, argnums=1)
    self.c_xx = jax.hessian(c, argnums=0)
    self.c_uu = jax.hessian(c, argnums=1)
def derivative(molecule, basis_name, method, order=1):
    """
    Convenience function for computing the full nuclear derivative tensor at some
    order for a particular energy method, molecule, and basis set. May be
    memory-intensive.

    For gradients choose order=1, hessian order=2, cubic derivative tensor
    order=3, quartic order=4. Higher-order derivatives should use the
    ``partial_derivative`` utility.

    Parameters
    ----------
    molecule : psi4.Molecule
        Molecule providing geometry, charges and the xyz file written to the
        current working directory as ``geom.xyz``.
    basis_name : str
        Basis set name.
    method : str
        One of 'scf', 'hf', 'rhf', 'mp2', 'ccsd', 'ccsd(t)'.
    order : int
        Derivative order, 1 through 4.

    Returns
    -------
    numpy.ndarray
        Tensor of shape ``(ncart,) * order`` (``ncart`` = number of Cartesian
        coordinates), rounded to 10 decimals.

    Raises
    ------
    ValueError
        If ``method`` is not supported or ``order`` is not in 1..4.
    """
    if order not in (1, 2, 3, 4):
        raise ValueError(
            "Order {} derivative tensors are not supported; use partial_derivative "
            "for higher orders.".format(order))

    # Select the energy function before any file/Psi4 work so bad input fails fast.
    if method == 'scf' or method == 'hf' or method == 'rhf':
        def electronic_energy(geom, basis_name, xyz_path, nuclear_charges, charge,
                              deriv_order=order, return_aux_data=False):
            return restricted_hartree_fock(geom, basis_name, xyz_path, nuclear_charges, charge,
                                           deriv_order=order, return_aux_data=False)
    elif method == 'mp2':
        def electronic_energy(geom, basis_name, xyz_path, nuclear_charges, charge,
                              deriv_order=order, return_aux_data=False):
            return restricted_mp2(geom, basis_name, xyz_path, nuclear_charges, charge)
    elif method == 'ccsd':
        def electronic_energy(geom, basis_name, xyz_path, nuclear_charges, charge,
                              deriv_order=order, return_aux_data=False):
            return rccsd(geom, basis_name, xyz_path, nuclear_charges, charge)
    elif method == 'ccsd(t)':
        def electronic_energy(geom, basis_name, xyz_path, nuclear_charges, charge,
                              deriv_order=order, return_aux_data=False):
            return rccsd_t(geom, basis_name, xyz_path, nuclear_charges, charge)
    else:
        raise ValueError(
            "Desired electronic structure method not understood. "
            "Use 'scf' 'hf' 'mp2' 'ccsd' or 'ccsd(t)' ")

    geom2d = np.asarray(molecule.geometry())
    geom = jnp.asarray(geom2d.flatten())
    charge = molecule.molecular_charge()
    nuclear_charges = jnp.asarray(
        [molecule.charge(i) for i in range(geom2d.shape[0])])

    # Save xyz file, get path
    xyz_file_name = "geom.xyz"
    molecule.save_xyz_file(xyz_file_name, True)
    xyz_path = os.path.join(os.path.abspath(os.getcwd()), xyz_file_name)
    dim = geom.shape[0]

    # TODO: support an internal-coordinate wrapper (internal -> Cartesian -> energy)
    # so JAX returns the internal-coordinate derivative tensor instead.

    # Nest jacfwd once per order, always w.r.t. the flattened geometry (arg 0).
    energy_deriv = electronic_energy
    for _ in range(order):
        energy_deriv = jacfwd(energy_deriv, 0)
    deriv = energy_deriv(geom, basis_name, xyz_path, nuclear_charges, charge,
                         deriv_order=order, return_aux_data=False)
    deriv = jnp.round(deriv.reshape((dim,) * order), 10)
    return np.asarray(deriv)
def write_integrals(molecule, basis_name, order, address):
    """
    Write one partial derivative of the one- and two-electron integrals to disk
    using the Quax integrals code.

    Temporary benchmarking helper (TODO: only diagonal partials are needed).
    Goal: benchmark partial derivatives with address = (2, 2, ..., 2) up to
    quartic, i.e. derivative vectors [0,0,1,0,...], [0,0,2,0,...], ...

    Parameters
    ----------
    molecule : psi4.Molecule
    basis_name : str
    order : int
        Derivative order, 1 through 6.
    address : sequence of int
        Flattened Cartesian coordinate indices, one per order of differentiation.

    Side effects
    ------------
    Appends datasets ``overlap/kinetic/potential_deriv{order}_{idx}`` to
    ``oei_partials.h5`` and ``eri_deriv{order}_{idx}`` to ``eri_partials.h5``.

    Returns
    -------
    int
        Always 0.

    Raises
    ------
    ValueError
        If ``order`` is outside 1..6 or ``len(address) != order``.
    """
    if order not in range(1, 7):
        raise ValueError(
            "Error: Order {} partial derivatives are not exposed to the API.".format(order))
    if len(address) != order:
        raise ValueError(
            "The length of 'address' ({}) must equal the order of differentiation ({})."
            .format(len(address), order))

    geom2d = np.asarray(molecule.geometry())
    geom_list = geom2d.reshape(-1).tolist()
    nuclear_charges = jnp.asarray(
        [molecule.charge(i) for i in range(geom2d.shape[0])])
    basis_dict = build_basis_set(molecule, basis_name)
    kwargs = {"basis_dict": basis_dict, "nuclear_charges": nuclear_charges}

    # Wrappers take every Cartesian coordinate as a separate positional argument
    # so jacfwd can differentiate with respect to a single coordinate index.
    def oei_wrapper(*args, **kwargs):
        geom = jnp.asarray(args)
        basis_dict = kwargs['basis_dict']
        nuclear_charges = kwargs['nuclear_charges']
        S, T, V = oei.oei_arrays(geom.reshape(-1, 3), basis_dict, nuclear_charges)
        return S, T, V

    def tei_wrapper(*args, **kwargs):
        geom = jnp.asarray(args)
        basis_dict = kwargs['basis_dict']
        G = tei.tei_array(geom.reshape(-1, 3), basis_dict)
        return G

    # TODO can these be passed list of lists to generate all partials at a given order?
    oei_deriv, tei_deriv = oei_wrapper, tei_wrapper
    for coord in address:
        oei_deriv = jacfwd(oei_deriv, coord)
        tei_deriv = jacfwd(tei_deriv, coord)
    dS, dT, dV = oei_deriv(*geom_list, **kwargs)
    dG = tei_deriv(*geom_list, **kwargs)

    # Convert address tuple to (NCART,) derivative vector
    deriv_vec = [0] * len(geom_list)
    for coord in address:
        deriv_vec[coord] += 1
    deriv_vec = np.asarray(deriv_vec)

    # Get flattened upper triangle index for this derivative vector
    flat_idx = get_deriv_vec_idx(deriv_vec)

    # Open in append mode ('a', not 'w') so datasets from earlier calls are kept.
    with h5py.File("oei_partials.h5", "a") as f:
        f.create_dataset("overlap_deriv" + str(order) + "_" + str(flat_idx), data=dS)
        f.create_dataset("kinetic_deriv" + str(order) + "_" + str(flat_idx), data=dT)
        f.create_dataset("potential_deriv" + str(order) + "_" + str(flat_idx), data=dV)

    with h5py.File("eri_partials.h5", "a") as f:
        f.create_dataset("eri_deriv" + str(order) + "_" + str(flat_idx), data=dG)

    return 0
def costate_dynamics(self, state, costate, control, homotopy, *params):
    """Time derivative of the costate: minus the state-gradient of the Hamiltonian.

    ``self.hamiltonian`` is differentiated (forward mode) with respect to its
    first argument, ``state``, and the result is negated.
    """
    dH_dstate = jacfwd(self.hamiltonian, argnums=0)
    gradient = dH_dstate(state, costate, control, homotopy, *params)
    return -gradient
def jacobian_risk_concentration_vector(self):
    """Return a jitted function computing the Jacobian of ``self.risk_concentration_vector``."""
    forward_jacobian = jacfwd(self.risk_concentration_vector)
    compiled = jit(forward_jacobian)
    return compiled
def __init__(self, parameter=0.0):
    """Planar two-link arm task: move the arm tip to ``goal_coord``.

    State is ``[th1, th2, dth1, dth2, Dx, Dy]``, where ``(Dx, Dy)`` is the tip
    position minus ``goal_coord``. Action is the pair of joint torques
    ``(t1, t2)``.

    Args:
        parameter: gravitational acceleration ``g`` used by the dynamics.
    """
    self.H, self.dt, self.state_dim, self.action_dim = 200, 0.01, 6, 2
    self.m1, self.m2, self.l1, self.l2, self.g = 1.0, 1.0, 1.0, 1.0, parameter
    self.initial_th, self.goal_coord = (np.pi / 4, np.pi / 2), np.array([0.0, 1.8])
    th1_0, th2_0 = self.initial_th
    self.initial_state = np.array([
        th1_0,
        th2_0,
        0.0,
        0.0,
        self.l1 * np.cos(th1_0) + self.l2 * np.cos(th1_0 + th2_0) - self.goal_coord[0],
        self.l1 * np.sin(th1_0) + self.l2 * np.sin(th1_0 + th2_0) - self.goal_coord[1],
    ])
    self.viewer, self.state, self.h = None, self.initial_state, 0
    self.nsamples = 0

    @jax.jit
    def _step(x, u):
        """One explicit-Euler step of the arm dynamics (pure; safe to jit/differentiate)."""
        m1, m2, l1, l2, g = self.m1, self.m2, self.l1, self.l2, self.g
        th1, th2, dth1, dth2, _, _ = x
        t1, t2 = u
        # Joint accelerations solve A @ [ddth1, ddth2] = b.
        a11 = (m1 + m2) * l1**2 + m2 * l2**2 + 2 * m2 * l1 * l2 * np.cos(th2)
        a12 = m2 * l2**2 + m2 * l1 * l2 * np.cos(th2)
        a22 = m2 * l2**2
        b1 = (t1 + m2 * l1 * l2 * (2 * dth1 + dth2) * dth2 * np.sin(th2)
              - m2 * l2 * g * np.sin(th1 + th2) - (m1 + m2) * l1 * g * np.sin(th1))
        b2 = (t2 - m2 * l1 * l2 * dth1**2 * np.sin(th2)
              - m2 * l2 * g * np.sin(th1 + th2))
        A, b = np.array([[a11, a12], [a12, a22]]), np.array([b1, b2])
        # Solve the 2x2 system directly rather than forming inv(A).
        ddth1, ddth2 = np.linalg.solve(A, b)
        th1, th2 = th1 + dth1 * self.dt, th2 + dth2 * self.dt
        dth1, dth2 = dth1 + ddth1 * self.dt, dth2 + ddth2 * self.dt
        Dx, Dy = (
            l1 * np.cos(th1) + l2 * np.cos(th1 + th2) - self.goal_coord[0],
            l1 * np.sin(th1) + l2 * np.sin(th1 + th2) - self.goal_coord[1],
        )
        return np.array([th1, th2, dth1, dth2, Dx, Dy])

    def f(x, u):
        # Count dynamics calls outside the jitted body: Python code inside a
        # jax.jit function only runs while tracing, so a counter there would
        # not count calls.
        self.nsamples += 1
        return _step(x, u)

    @jax.jit
    def c(x, u):
        # Control effort plus squared tip-to-goal distance (x[4:] = (Dx, Dy)).
        return 0.1 * u @ u + x[4:] @ x[4:]

    self.f, self.f_x, self.f_u = (
        f,
        jax.jit(jax.jacfwd(_step, argnums=0)),
        jax.jit(jax.jacfwd(_step, argnums=1)),
    )
    self.c, self.c_x, self.c_u, self.c_xx, self.c_uu = (
        c,
        jax.jit(jax.grad(c, argnums=0)),
        jax.jit(jax.grad(c, argnums=1)),
        jax.jit(jax.hessian(c, argnums=0)),
        jax.jit(jax.hessian(c, argnums=1)),
    )
def dhxdx(self, x):
    """Jacobian of ``self.hx`` evaluated at ``x`` (forward-mode autodiff)."""
    jacobian_of_hx = jacfwd(self.hx)
    return jacobian_of_hx(x)
cost = integral x.T*Q*x + u.T*R*u """ #ref Bertsekas, p.151 #first, try to solve the ricatti equation X = np.matrix(scipy.linalg.solve_continuous_are(A, B, Q, R)) #compute the LQR gain K = np.matrix(scipy.linalg.inv(R) * (B.T * X)) eigVals, eigVecs = scipy.linalg.eig(A - B * K) return K, X, eigVals jac_A = jax.jacfwd(jax_dynamics, argnums=0) jac_B = jax.jacfwd(jax_dynamics, argnums=1) data_folder = '../data/pendulum/' test_data = data_loader.load_test_dataset(1, 5, '../data/pendulum/', sp=1, obs_f=None) # data_type: seen or unseen obc, obs, paths, path_lengths = test_data fes_env = [] # list of list valid_env = [] time_env = [] time_total = []
def _test_primitive(self, primitive: Optional[Primitive], shapes, dtype, params):
    """Check the Jacobian and structure rules registered for ``primitive``.

    Builds inputs of the given ``shapes``/``dtype`` and computes forward- and
    reverse-mode Jacobians of the primitive's function w.r.t. every input. For
    each input position it then:
      * compares those Jacobians with ``rules.JACOBIAN_RULES[primitive]`` (if
        registered; otherwise a warning is emitted and only fwd/rev are compared);
      * checks the diagonal, constant-diagonal and broadcast axes declared by
        ``rules.STRUCTURE_RULES[primitive]`` against the Jacobian.
    Raises ``absltest.SkipTest`` for the two unsupported (primitive, position)
    combinations below.
    """
    xs = _get_inputs(shapes, dtype)
    n = len(xs)
    eqn, f = _get_f_and_eqn(params, primitive, *xs)

    out = f(*xs)
    # Abstract value of the output, passed to the rule functions below.
    cts_in = ShapedArray(out.shape, out.dtype)

    # Jacobians w.r.t. all inputs at once; indexed per input position below.
    argnums = tuple(range(n))
    js_fwd = jax.jacfwd(f, argnums)(*xs)
    js_rev = jax.jacrev(f, argnums)(*xs)

    for idx in range(n):
        if primitive == lax.conv_general_dilated_p and idx == 0:
            raise absltest.SkipTest('Jacobian of CNN wrt inputs not implemented.')

        if primitive == lax.div_p and idx == 1:
            raise absltest.SkipTest('Division is linear only in the first arg.')

        invals = _get_invals(idx, *xs)
        j_fwd, j_rev = js_fwd[idx], js_rev[idx]

        if primitive in rules.JACOBIAN_RULES:
            j_rule = rules.JACOBIAN_RULES[primitive](eqn, idx, invals, cts_in)
        else:
            warnings.warn(f'Jacobian rule for {primitive} at position {idx} not '
                          f'found.')
            j_rule = None

        with self.subTest(f'Jacobian ({idx})'):
            self._compare_jacobians(j_fwd, j_rev, j_rule, primitive)

        # NOTE(review): unlike JACOBIAN_RULES, a missing STRUCTURE_RULES entry is
        # not guarded here and would raise KeyError.
        structure = rules.STRUCTURE_RULES[primitive](eqn, idx, invals, cts_in)

        # Prefer the rule-provided Jacobian when one exists.
        j = j_fwd if j_rule is None else j_rule

        if primitive == lax.reshape_p:
            out_ndim = xs[0].ndim
            # Permute the Jacobian's axes by argsort(structure.in_trace) and reshape
            # to xs[0].shape on both sides -- presumably so the output axes line up
            # with the input axes for the diagonal checks below; confirm.
            j = j.transpose(
                tuple(xs[0].ndim + i for i in onp.argsort(structure.in_trace)) +
                tuple(i for i in onp.argsort(structure.in_trace)))
            j = j.reshape(
                xs[0].shape +
                tuple(xs[0].shape[i] for i in onp.argsort(structure.in_trace)))
        else:
            out_ndim = out.ndim

        # Input axes of the Jacobian start after the output axes (offset out_ndim).
        with self.subTest(f'Diagonal axes ({idx})'):
            for i, o in zip(structure.in_diagonal, structure.out_diagonal):
                self._assert_is_diagonal(j=j, axis1=out_ndim + i[idx], axis2=o,
                                         constant_diagonal=False)

        with self.subTest(f'Constant diagonal axes ({idx})'):
            for i, o in zip(structure.in_trace, structure.out_trace):
                self._assert_is_diagonal(j=j, axis1=out_ndim + i, axis2=o,
                                         constant_diagonal=True)

        with self.subTest(f'Input broadcast axes ({idx})'):
            for i in structure.in_broadcast:
                self._assert_constant(j=j, axis=i)

        with self.subTest(f'Output broadcast axes ({idx})'):
            for i in structure.out_broadcast:
                self._assert_constant(j=j, axis=i)
def J_inv(x):
    """Return the inverse Jacobian of the system ``f`` at ``x`` as a JAX array.

    ``f`` is a free name resolved from the enclosing scope. Its Jacobian is
    evaluated with forward-mode autodiff and then inverted explicitly.
    """
    jac_at_x = jax.jacfwd(f)(x)
    inverse = jnp.linalg.inv(jac_at_x)
    return jnp.asarray(inverse)
def _hessian(
    func: Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray], float]
) -> Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray], jnp.ndarray]:
    """Return a jitted Hessian of ``func`` w.r.t. its first argument.

    Built as forward-over-reverse: ``jacfwd`` of the reverse-mode gradient.
    The returned callable takes the same three arguments as ``func``.
    """
    first_arg_grad = jacrev(func, argnums=0)
    first_arg_hessian = jacfwd(first_arg_grad, argnums=0)
    return jit(first_arg_hessian)
def linearize(f, s, u):
    """Linearize the function `f(s,u)` around `(s,u)`.

    Returns ``(A, B)`` with ``A = df/ds`` and ``B = df/du``, both evaluated at
    ``(s, u)`` via forward-mode autodiff.
    """
    jacobians = jax.jacfwd(f, argnums=(0, 1))(s, u)
    jac_wrt_s, jac_wrt_u = jacobians
    return jac_wrt_s, jac_wrt_u
def eval_tasks(mpNet1, mpNet2, test_data, folder, filename, IsInCollision, normalize_func, unnormalize_func, informer, init_informer, system, dynamics, xdot, jax_dynamics, enforce_bounds, traj_opt, step_sz, num_steps):
    """Run ``plan`` on every (environment i, path j) in ``test_data`` and report success.

    For each valid demonstration path (``path_lengths[i][j] >= 2``) the stored
    controls are re-propagated with ``dynamics``/``enforce_bounds`` at ``step_sz``
    to build a dense reference ``state`` list. Then ``plan`` is called from the
    path's first state toward ``sgs[i][j][1]``. Running accuracy and timing are
    printed.

    NOTE(review): ``mpNet1``, ``mpNet2``, ``folder``, ``filename``,
    ``normalize_func``, ``unnormalize_func`` and ``xdot`` are not used by the
    active code (``xdot`` appears only in the disabled string block below).

    Returns:
        ``(np.array(fes_env), np.array(valid_env))``: per environment, the list of
        ``plan`` results (``fp``) and the 0/1 validity flags of each path.
    """
    obc, obs, paths, sgs, path_lengths, controls, costs = test_data
    obc = obc.astype(np.float32)
    obc = torch.from_numpy(obc)
    fes_env = []   # list of list
    valid_env = []
    time_env = []
    time_total = []
    # NOTE(review): jac_A / jac_B are only referenced inside the disabled string block below.
    jac_A = jax.jacfwd(jax_dynamics, argnums=0)
    jac_B = jax.jacfwd(jax_dynamics, argnums=1)
    for i in range(len(paths)):
        time_path = []
        fes_path = []   # 1 for feasible, 0 for not feasible
        valid_path = []      # if the feasibility is valid or not
        # save paths to different files, indicated by i
        # feasible paths for each env
        # NOTE(review): uses len(paths[0]) for every env -- assumes all envs have
        # the same number of paths; confirm.
        for j in range(len(paths[0])):
            time0 = time.time()
            time_norm = 0.
            fp = 0 # indicator for feasibility
            print ("step: i="+str(i)+" j="+str(j))
            p1_ind=0
            p2_ind=0
            p_ind=0
            if path_lengths[i][j]<2:
                # invalid, feasible = 0, and path count = 0
                fp = 0
                valid_path.append(0)
            if path_lengths[i][j]>=2:
                state_i = []
                state = paths[i][j]
                # obtain the sequence
                p_start = paths[i][j][0]
                detail_paths = [p_start]
                detail_controls = []
                detail_costs = []
                state = [p_start]
                control = []
                cost = []
                # Re-propagate each stored control segment k for costs[i][j][k]
                # time units in increments of step_sz.
                for k in range(len(controls[i][j])):
                    #state_i.append(len(detail_paths)-1)
                    max_steps = int(costs[i][j][k]/step_sz)
                    accum_cost = 0.
                    #print('p_start:')
                    #print(p_start)
                    #print('data:')
                    #print(paths[i][j][k])
                    # modify it because of small difference between data and actual propagation
                    p_start = paths[i][j][k]
                    state[-1] = paths[i][j][k]
                    for step in range(1,max_steps+1):
                        p_start = dynamics(p_start, controls[i][j][k], step_sz)
                        p_start = enforce_bounds(p_start)
                        detail_paths.append(p_start)
                        # NOTE(review): appends the whole control sequence, not
                        # controls[i][j][k]; detail_controls is not read later here.
                        detail_controls.append(controls[i][j])
                        detail_costs.append(step_sz)
                        accum_cost += step_sz
                        # step % 1 == 0 is always true, so every step is recorded.
                        if (step % 1 == 0) or (step == max_steps):
                            state.append(p_start)
                            #print('control')
                            #print(controls[i][j])
                            control.append(controls[i][j][k])
                            cost.append(accum_cost)
                            accum_cost = 0.
                #print('p_start:')
                #print(p_start)
                #print('data:')
                #print(paths[i][j][-1])
                # Snap the last propagated state back onto the dataset's endpoint.
                state[-1] = paths[i][j][-1]
                fp = 0
                valid_path.append(1)
                start_node = Node(paths[i][j][0])
                goal_node = Node(sgs[i][j][1])
                #goal_node = Node(paths[i][j][-1])
                print(goal_check(goal_node, Node(sgs[i][j][1]), system))
                #goal_node.S0 = np.diag([1.,1.,0,0])
                #goal_node.rho0 = 1.
                path = [start_node, goal_node]
                #step_sz = DEFAULT_STEP
                MAX_NEURAL_REPLAN = 21
                fp = plan(obs[i], obc[i], start_node, goal_node, state, informer, init_informer, system, dynamics, enforce_bounds, \
                    IsInCollision, traj_opt, step_sz, num_steps)
                time1 = time.time() - time0
                time1 -= time_norm
                time_path.append(time1)
                print('test time: %f' % (time1))
                # Disabled funnel-construction / visualization code, kept as an
                # unused string literal.
                """
                if fp:
                    # only for successful paths
                    # goal compute the stability region
                    path[-1].x = sgs[i][j][1] # change to real goal
                    path[-1].S0 = np.diag([1.,1.,0.,0.])
                    path[-1].rho0 = 1.0
                    # reversely construct funnel
                    lazyFunnel(path[0], path[-1], xdot, enforce_bounds, jac_A, jac_B, traj_opt, system=system, step_sz=step_sz)
                    fig = plt.figure()
                    ax = fig.add_subplot(111)
                    # after plan, generate the trajectory, and check if it is within the region
                    xs = plot_trajectory(ax, path[0], path[-1], dynamics, enforce_bounds, collision_check, step_sz)

                    params = {}
                    params['obs_w'] = 6.
                    params['obs_h'] = 6.
                    params['integration_step'] = step_sz
                    fig = plt.figure()
                    ax = fig.add_subplot(111)
                    animator = AcrobotVisualizer(Acrobot(), params)
                    animation_acrobot(fig, ax, animator, xs, obs_i)
                    plt.waitforbuttonpress()
                """
                # write the path
                #print('planned path:')
                #print(path)
                #path = [p.numpy() for p in path]
                #path = np.array(path)
                #np.savetxt('path_%d.txt' % (j), path, fmt='%f')
            fes_path.append(fp)
            print('env %d accuracy up to now: %f' % (i, (float(np.sum(fes_path))/ np.sum(valid_path))))
        time_env.append(time_path)
        time_total += time_path
        print('average test time up to now: %f' % (np.mean(time_total)))
        fes_env.append(fes_path)
        valid_env.append(valid_path)
        print('accuracy up to now: %f' % (float(np.sum(fes_env)) / np.sum(valid_env)))
    #if filename is not None:
    #    pickle.dump(time_env, open(filename, "wb" ))
    #    #print(fp/tp)
    return np.array(fes_env), np.array(valid_env)
def LS():
    """Least-squares solve of ``J @ xi = -L(zXi, *x)``, with ``J = dL/dzXi``.

    ``L``, ``zXi`` and ``x`` are free names from the enclosing scope. ``J`` is
    the forward-mode Jacobian of ``L`` w.r.t. its first argument.
    """
    jac = jacfwd(L, 0)(zXi, *x)
    rhs = -L(zXi, *x)
    solution, _residuals, _rank, _singular_values = np.linalg.lstsq(jac, rhs, rcond=None)
    return solution
def state_dynamics_jac_state(self, state, control, *params):
    """Jacobian of ``self.state_dynamics`` with respect to ``state`` (forward mode)."""
    d_dynamics_d_state = jacfwd(self.state_dynamics, argnums=0)
    return d_dynamics_d_state(state, control, *params)
def _value_and_grad(f, x, forward_mode_differentiation=False):
    """Return ``(f(x), df/dx)``.

    Uses reverse mode (``value_and_grad``) by default. With
    ``forward_mode_differentiation=True`` it evaluates ``f`` and ``jacfwd(f)``
    separately instead.
    """
    if not forward_mode_differentiation:
        return value_and_grad(f)(x)
    value = f(x)
    derivative = jacfwd(f)(x)
    return value, derivative
def solve_direct(self, states, controls, T, homotopy, boundaries):
    """Solve the optimal control problem by direct collocation with IPOPT (PyGMO).

    The initial guess is packed into a decision vector
    ``z = [T, state_0, control_0, ..., state_{n-1}, control_{n-1}]``. The
    fitness vector is ``[J, e]``:
      * ``J`` is the trapezoidal integral of ``self.lagrangian`` over the nodes;
      * ``e`` holds the squared state-collocation residuals followed by the
        squared boundary residuals.

    Side effects: assigns the PyGMO problem methods (``fitness``, ``gradient``,
    ``gradient_sparsity``, ``get_bounds``, ``get_nobj``, ``get_nec``) onto
    ``self``, writes plots before and after solving, and saves the champion
    decision vector to ``decision.npy``.

    Args:
        states: initial-guess states, shape ``(n, self.state_dim)``.
        controls: initial-guess controls, shape ``(n, self.control_dim)``.
        T: initial guess for the final time.
        homotopy: forwarded to ``self.lagrangian``.
        boundaries: callable ``(first_state, last_state) -> (e2, e3)`` giving
            the boundary residuals.
    """
    # sanity
    assert states.shape[0] == controls.shape[0]
    assert states.shape[1] == self.state_dim
    assert controls.shape[1] == self.control_dim

    # system parameters
    params = self.params.values()

    # number of collocation nodes
    n = states.shape[0]

    # decision vector bounds (T in [0, inf), then per-node state/control bounds)
    @jit
    def get_bounds():
        zl = np.hstack((self.state_lb, self.control_lb))
        zl = np.tile(zl, n)
        zl = np.hstack((np.zeros(1), zl))
        zu = np.hstack((self.state_ub, self.control_ub))
        zu = np.tile(zu, n)
        zu = np.hstack((np.array([np.inf]), zu))
        return zl, zu

    # decision vector maker
    @jit
    def flatten(states, controls, T):
        z = np.hstack((states, controls)).flatten()
        # atleast_1d instead of a Python list: jax.numpy rejects list operands.
        return np.hstack((np.atleast_1d(T), z))

    # decision vector translator
    @jit
    def unflatten(z):
        T = z[0]
        z = z[1:].reshape(n, self.state_dim + self.control_dim)
        states = z[:, :self.state_dim]
        controls = z[:, self.state_dim:]
        return states, controls, T

    # fitness vector
    print('Compiling fitness...')

    @jit
    def fitness(z):
        # translate decision vector
        states, controls, T = unflatten(z)

        # time grid
        times = np.linspace(0, T, n)

        # objective
        L = vmap(lambda state, control: self.lagrangian(
            state, control, homotopy, *params))
        L = L(states, controls)
        J = np.trapz(L, dx=T / (n - 1))

        # state dynamics constraints and boundary constraints
        e1 = self.collocate_state(states, controls, times, *params)
        e2, e3 = boundaries(states[0, :], states[-1, :])
        e = np.hstack((e1.flatten(), e2, e3))**2

        # fitness vector
        return np.hstack((J, e))

    # sparse Jacobian
    print('Compiling Jacobian and its sparsity...')
    gradient = jit(jacfwd(fitness))
    z = flatten(states, controls, T)
    # NOTE(review): the sparsity pattern is taken from the nonzeros of the
    # Jacobian at the initial guess only; entries that happen to vanish there
    # are treated as structurally zero.
    rows, cols = np.nonzero(gradient(z))
    sparse_id = np.vstack((rows, cols)).T
    # Index with a tuple of index arrays: indexing with a *list* of arrays is
    # rejected by current NumPy and JAX.
    sparse_gradient = jit(lambda z: gradient(z)[rows, cols])
    gradient_sparsity = jit(lambda: sparse_id)
    print('Jacobian has {} elements.'.format(sparse_id.shape[0]))

    # assign PyGMO problem methods
    self.fitness = fitness
    self.gradient = sparse_gradient
    self.gradient_sparsity = gradient_sparsity
    self.get_bounds = get_bounds
    self.get_nobj = jit(lambda: 1)
    nec = fitness(z).shape[0] - 1
    self.get_nec = jit(lambda: nec)

    # plot before
    states, controls, T = unflatten(z)
    self.plot('../img/direct_before.png', states, dpi=1000)

    # solve NLP with IPOPT
    print('Solving...')
    prob = pg.problem(udp=self)
    algo = pg.ipopt()
    algo.set_integer_option('max_iter', 1000)
    algo = pg.algorithm(algo)
    algo.set_verbosity(1)
    pop = pg.population(prob=prob, size=0)
    pop.push_back(z)
    pop = algo.evolve(pop)

    # save and plot solution
    z = pop.champion_x
    np.save('decision.npy', z)
    states, controls, T = unflatten(z)
    self.plot('../img/direct_after.png', states, dpi=1000)
def test_complex_input_jacfwd_raises_error(self):
    # jacfwd without holomorphic=True must reject a complex-valued input.
    jacobian_of_sin = jacfwd(lambda x: np.sin(x))
    self.assertRaises(TypeError, lambda: jacobian_of_sin(1 + 2j))
def jac_GRU(GRU_params, hidden, t, X):
    """Forward-mode Jacobian of ``GRU_forward`` w.r.t. ``X`` (positional argument 3).

    ``GRU_forward`` is a free name defined elsewhere in the module.
    """
    d_forward_d_X = jax.jacfwd(GRU_forward, argnums=3)
    return d_forward_d_X(GRU_params, hidden, t, X)
def initialize(opt: LevenbergMaquardtBayes, obj: RBObjective, x, n):
    """Build a Levenberg-Marquardt optimizer with Bayesian regularization.

    Args:
        opt: unpacked as ``(μi, μs, μmin, μmax)``. These are the initial damping,
            the multiplicative damping step, and the damping floor/ceiling.
        obj: objective exposing ``errors(θ, x, y)``, ``cost(e)``,
            ``regularizer(θ)`` and ``max_iters``.
        x: input data closed over by every step (``x.size`` and ``x.shape[0]``
            enter the hyperparameter formulas).
        n: number of parameters; sizes the damping identity ``I``.

    Returns:
        ``Optimizer(init, condition, update)``.
    """
    errors = obj.errors
    jacobian = jax.jacfwd(errors)
    I = np.eye(n, dtype=np.float32)
    # Bayesian regularization progress parameter
    τ = np.int32(1)
    # Bayesian regularization hyperparameters (initial values)
    α, β = np.float32(np.minimum(τ / x.size, 1) / x.size), np.float32(1)
    # Levenberg-Maquardt parameters
    μi, μs, μmin, μmax = opt

    def differentiate(θ, e, y, Λ):
        # Use the current hyperparameters Λ = (α, β). Closing over the initial
        # α, β (the previous behavior) froze them, so the Bayesian updates never
        # reached the Gauss-Newton Hessian H or the gradient Je.
        α, β = Λ
        J = np.squeeze(jacobian(θ, x, y))
        H = β * J.T @ J + α * I
        Je = β * J.T @ e + α * θ.T
        return H, Je

    def _lm_cond(G, state):
        # Keep increasing damping while the candidate objective has not dropped
        # below the previous one G, up to μmax. Runs at least once because the
        # loop starts from G itself.
        return (G <= state.G) & (state.μ <= μmax)

    def _lm_update(θ, H, Je, y, Λ, state):
        α, β = Λ
        p = θ - solve(H + state.μ * I, Je, sym_pos="sym").T
        e = errors(p, x, y)
        C = obj.cost(e)
        # NOTE(review): the regularizer is evaluated at θ, not at the candidate p;
        # confirm this is intended.
        R = obj.regularizer(θ)
        G = np.float32(β * C + α * R)
        return LMState(p, e, G, C, R, state.μ * μs)

    def _bl_update(H, C, R, state):
        # Bayesian re-estimation of (α, β). γ = n - α·tr(H⁻¹) is presumably the
        # effective number of parameters (MacKay); confirm.
        G, (α, _), μ, τ = state
        tr_inv_H = np.trace(solve(H, I, sym_pos="sym"))
        γ = n - α * tr_inv_H
        α = np.float32(n / (2 * R + tr_inv_H))
        β = np.float32((x.shape[0] - γ) / (2 * C))
        return G, (α, β), μ, τ

    def _bl_restart(G, state):
        # Reset hyperparameters/damping and advance the progress counter τ.
        _, _, _, τ = state
        α = np.float32(np.minimum(τ / x.size, 1) / x.size)
        β = np.float32(1)
        return G, (α, β), μi, τ + 1

    def init(θ, data):
        y = data[1]
        e = errors(θ, x, y)
        i = np.int32(0)  # Iterations counter
        k = np.int32(1)  # Convergence counter
        C = obj.cost(e)
        R = obj.regularizer(θ)
        G = np.float32(β * C + α * R)  # Objective function
        return LMBTrainingState(θ, e, G, C, R, (α, β), μi, τ, i, k)

    def condition(state):
        return (state.k < 4) & (state.i < obj.max_iters)

    def update(data, state):
        y = data[1]
        H, Je = differentiate(state.θ, state.e, y, state.Λ)
        # Inner Levenberg-Maquardt update
        lm_state = LMState(state.θ, state.e, state.G, state.C, state.R, state.μ)
        lm_cond = partial(_lm_cond, state.G)
        lm_update = partial(_lm_update, state.θ, H, Je, y, state.Λ)
        θ, e, G, C, R, μ = while_loop(lm_cond, lm_update, lm_state)
        # Undo the last damping increase, then clamp to μmin.
        μ = np.where(μ < μmax, μ / μs, μ)
        μ = np.where(μmin < μ, μ, μmin)
        # Bayesian hyperparameter learning: restart if the objective grew,
        # otherwise re-estimate (α, β).
        bl_state = (G, state.Λ, μ, state.τ)
        bl_update = partial(_bl_update, H, C, R)
        bl_restart = partial(_bl_restart, state.G)
        G, Λ, μ, τ = cond(G > state.G, bl_state, bl_restart, bl_state, bl_update)
        # Count consecutive non-improving iterations; `condition` stops at 4.
        k = np.where(G >= state.G, state.k + 1, np.int32(1))
        return LMBTrainingState(θ, e, G, C, R, Λ, μ, τ, state.i + 1, k)

    return Optimizer(init, condition, update)
def partial_derivative(molecule, basis_name, method, order, address):
    """
    Compute one nth-order partial derivative of the energy of an electronic
    structure method with respect to a set of Cartesian coordinates.

    If the molecule has N Cartesian coordinates, the nuclear derivative tensor
    is N x N x N ... (one axis per order of differentiation). This function
    computes the single element of that tensor selected by ``address``.

    For example, with 9 Cartesian coordinates x1,y1,z1,x2,y2,z2,x3,y3,z3, the
    quartic derivative d^4E/dx1dy2(dz3)^2 has address (0, 4, 8, 8). Any
    permutation of that tuple, such as (4, 8, 0, 8), names the same derivative.
    Indices follow the order in which the coordinates appear in the molecule
    object.

    Parameters
    ----------
    molecule : psi4.Molecule
        A Psi4 Molecule object containing geometry, charge, multiplicity in a
        multiline string. Examples:
        molecule = psi4.geometry('''
                                 0 1
                                 H 0.0 0.0 -0.55000000000
                                 H 0.0 0.0  0.55000000000
                                 units bohr
                                 ''')

        molecule = psi4.geometry('''
                                 0 1
                                 O
                                 H 1 r1
                                 H 1 r2 2 a1

                                 r1 = 1.0
                                 r2 = 1.0
                                 a1 = 104.5
                                 units ang
                                 ''')
    basis_name : str
        A Gaussian basis set available in Psi4's basis set library.
    method : str
        One of 'scf', 'hf', 'rhf', 'mp2', 'ccsd', 'ccsd(t)'.
    order : int
        Order of the derivative, 1 through 6
        (1 -> gradient, 2 -> hessian, 3 -> cubic, ...).
    address : tuple
        The index at which the desired derivative appears in the derivative tensor.

    Returns
    -------
    partial_deriv : float
        The requested partial derivative of the energy for the given geometry
        and basis set.

    Raises
    ------
    Exception
        If ``len(address) != order`` or ``method`` is not supported.
    ValueError
        If ``order`` is outside 1..6.
    """
    if len(address) != order:
        raise Exception(
            "The length of the index coordinates given by 'address' arguments should be the same as the order of differentiation"
        )
    if not 1 <= order <= 6:
        raise ValueError(
            "Error: Order {} partial derivatives are not exposed to the API.".format(order))

    # Pick the energy function up front so an unsupported method fails before any I/O.
    if method == 'scf' or method == 'hf' or method == 'rhf':
        energy_fn, extra_kwargs = restricted_hartree_fock, {'return_aux_data': False}
    elif method == 'mp2':
        energy_fn, extra_kwargs = restricted_mp2, {}
    elif method == 'ccsd':
        energy_fn, extra_kwargs = rccsd, {}
    elif method == 'ccsd(t)':
        energy_fn, extra_kwargs = rccsd_t, {}
    else:
        raise Exception("Error: Method {} not supported.".format(method))

    geom2d = np.asarray(molecule.geometry())
    geom_list = geom2d.reshape(-1).tolist()
    charge = molecule.molecular_charge()
    nuclear_charges = jnp.asarray(
        [molecule.charge(i) for i in range(geom2d.shape[0])])

    # Get number of basis functions
    basis_set = psi4.core.BasisSet.build(molecule, 'BASIS', basis_name, puream=0)
    nbf = basis_set.nbf()

    # Save xyz file, get path
    xyz_file_name = "geom.xyz"
    molecule.save_xyz_file(xyz_file_name, True)
    xyz_path = os.path.join(os.path.abspath(os.getcwd()), xyz_file_name)

    kwargs = {
        "basis_name": basis_name,
        "xyz_path": xyz_path,
        "nuclear_charges": nuclear_charges,
        "charge": charge,
        "order": order
    }

    # TODO: support an internal-coordinate wrapper (internal -> Cartesian -> energy)
    # so JAX returns the internal-coordinate partial derivative instead.

    # If integrals already exist in the working directory and they are the correct
    # shape, reuse them.
    # TODO Can make this safer by including info HDF5 file with rounded geometry, atom labels, etc.
    if os.path.exists("eri_derivs.h5") and os.path.exists("oei_derivs.h5"):
        print("Found currently existing integral derivatives in your working directory. Trying to use them.")
        with h5py.File('oei_derivs.h5', 'r') as oeifile, h5py.File('eri_derivs.h5', 'r') as erifile:
            # Check if there are `order` datasets in the eri file
            correct_deriv_order = len(erifile) == order
            # Check nbf dimension of integral arrays; an empty file counts as a mismatch.
            dataset_names = list(oeifile.keys())
            correct_nbf = bool(dataset_names) and oeifile[dataset_names[0]].shape[0] == nbf
        if correct_deriv_order and correct_nbf:
            print("Integral derivatives appear to be correct. Avoiding recomputation.")
        else:
            print("Integral derivatives dimensions do not match requested derivative order and/or basis set. Recomputing integral derivatives")
            if os.path.exists("eri_derivs.h5"):
                print("Deleting two electron integral derivatives...")
                os.remove("eri_derivs.h5")
            if os.path.exists("oei_derivs.h5"):
                print("Deleting one electron integral derivatives...")
                os.remove("oei_derivs.h5")
            libint_initialize(xyz_path, basis_name, order)
            libint_finalize()
    elif os.path.exists("eri_partials.h5") and os.path.exists("oei_partials.h5"):
        print("Found currently existing partial derivatives in working directory. I hope you know what you are doing!")
    # Otherwise nothing is precomputed here (libint precomputation is disabled for this case).

    # Take every Cartesian coordinate as a separate positional argument so jacfwd
    # can differentiate with respect to a single coordinate index.
    def partial_wrapper(*args, **kwargs):
        geom = jnp.asarray(args)
        return energy_fn(geom, kwargs['basis_name'], kwargs['xyz_path'],
                         kwargs['nuclear_charges'], kwargs['charge'],
                         deriv_order=kwargs['order'], **extra_kwargs)

    # Nest jacfwd once per entry of address: innermost is address[0].
    deriv_fn = partial_wrapper
    for coord in address:
        deriv_fn = jacfwd(deriv_fn, coord)
    partial_deriv = deriv_fn(*geom_list, **kwargs)
    return partial_deriv
def body_fn(state):
    """Draw one candidate set of unconstrained initial params and test it.

    ``state`` is unpacked as ``(i, key, _, _)``; the returned tuple has the
    same layout, ``(i + 1, key, (params, pe, z_grad), is_valid)``, with the
    counter incremented and ``key`` advanced.

    Relies on names from the enclosing scope: ``model``, ``model_args``,
    ``model_kwargs``, ``radius``, ``prototype_params``, ``init_values``,
    ``init_strategy``, ``enum``, ``validate_grad`` and
    ``forward_mode_differentiation``.
    """
    i, key, _, _ = state
    key, subkey = random.split(key)
    if radius is None or prototype_params is None:
        # XXX: we don't want to apply enum to draw latent samples
        model_ = model
        if enum:
            from numpyro.contrib.funsor import enum as enum_handler
            # Strip the enum handler (keeping any substituted data) so that
            # latent sites are sampled instead of enumerated.
            if isinstance(model, substitute) and isinstance(
                    model.fn, enum_handler):
                model_ = substitute(model.fn.fn, data=model.data)
            elif isinstance(model, enum_handler):
                model_ = model.fn
        # Wrap model in a `substitute` handler to initialize from `init_loc_fn`.
        seeded_model = substitute(seed(model_, subkey),
                                  substitute_fn=init_strategy)
        model_trace = trace(seeded_model).get_trace(
            *model_args, **model_kwargs)
        constrained_values, inv_transforms = {}, {}
        for k, v in model_trace.items():
            # Only non-observed, non-discrete sample sites get an initial value.
            if (v["type"] == "sample" and not v["is_observed"]
                    and not v["fn"].support.is_discrete):
                constrained_values[k] = v["value"]
                with helpful_support_errors(v):
                    inv_transforms[k] = biject_to(v["fn"].support)
        # Invert the support bijections to obtain unconstrained values.
        params = transform_fn(
            inv_transforms,
            {k: v for k, v in constrained_values.items()},
            invert=True,
        )
    else:
        # this branch doesn't require tracing the model
        params = {}
        for k, v in prototype_params.items():
            if k in init_values:
                params[k] = init_values[k]
            else:
                # Uniform draw between -radius and radius, shaped like the prototype.
                params[k] = random.uniform(subkey, jnp.shape(v),
                                           minval=-radius, maxval=radius)
                # Fresh subkey so the next parameter gets an independent draw.
                key, subkey = random.split(key)
    potential_fn = partial(potential_energy, model, model_args,
                           model_kwargs, enum=enum)
    if validate_grad:
        # A candidate is valid only if both the potential and its gradient are finite.
        if forward_mode_differentiation:
            pe = potential_fn(params)
            z_grad = jacfwd(potential_fn)(params)
        else:
            pe, z_grad = value_and_grad(potential_fn)(params)
        z_grad_flat = ravel_pytree(z_grad)[0]
        is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad_flat))
    else:
        pe = potential_fn(params)
        is_valid = jnp.isfinite(pe)
        z_grad = None
    return i + 1, key, (params, pe, z_grad), is_valid
def __init__(self, g=9.81):
    """Configure a planar quadrotor environment: parameters, dynamics and cost.

    The state is ``(x, y, th, xdot, ydot, thdot)`` (``state_size = 6``) and
    the action is the pair of rotor thrusts ``(u1, u2)`` (``action_size = 2``).
    Both dynamics functions take a single ``(x, u)`` tuple and advance the
    state by one explicit-Euler step of length ``self.dt``.

    Args:
        g: Gravitational acceleration, used by both dynamics models and to
            compute ``hover_input``.

    Attributes set here include:
        _dynamics / _dynamics_real: nominal and wind-perturbed step functions.
        _dynamics_der / _dynamics_real_der: their forward-mode Jacobians.
        _cost(x, u, i): quadratic stage cost, or the terminal cost when
            ``i == self.h``.
        _costgrad(x, u, i): ``[l_x, l_u, l_xx, l_ux, l_uu]`` of ``_cost``.
    """
    self.initialized = False
    self.dt = .05
    self.m = 0.1  # kg
    self.L = 0.2  # m
    self.I = 0.004  # inertia, kg*m^2
    self.g = g
    # Per-rotor thrust whose sum (m*g) cancels gravity when th == 0.
    self.hover_input = np.array([self.m*self.g/2., self.m*self.g/2.])
    self.viewer = None
    self.action_size = 2
    self.state_size = 6
    self.wind_force = 0.2
    self.initial_state = np.array([1.0, 1.0, 0.0, 0.0, 0.0, 0.0])
    self.goal_state = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
    self.h = 100  # horizon; step index h selects the terminal cost

    @jax.jit
    def _dynamics(inp):
        # Nominal dynamics: rotor thrust and gravity only.
        state, u = inp
        x, y, th, xdot, ydot, thdot = state
        u1, u2 = u
        g = self.g
        m = self.m
        L = self.L
        I = self.I
        dt = self.dt
        xddot = -(u1+u2)*np.sin(th)/m  # xddot
        yddot = (u1+u2)*np.cos(th)/m - g  # yddot
        thddot = L*(u2 - u1)/I  # thetaddot
        state_dot = np.array([xdot, ydot, thdot, xddot, yddot, thddot])
        new_state = state + state_dot*dt  # explicit Euler step
        return new_state

    @jax.jit
    def _wind_field(x, y):
        # Wind force proportional to position: [F_x, F_y].
        return [self.wind_force*x, self.wind_force*y]
    self._wind_field = _wind_field

    @jax.jit
    def _dynamics_real(inp):
        # Same as _dynamics plus the wind force from _wind_field.
        state, u = inp
        x, y, th, xdot, ydot, thdot = state
        u1, u2 = u
        g = self.g
        m = self.m
        L = self.L
        I = self.I
        dt = self.dt
        wind = self._wind_field(x, y)
        xddot = -(u1+u2)*np.sin(th)/m + wind[0]/m  # xddot
        yddot = (u1+u2)*np.cos(th)/m - g + wind[1]/m  # yddot
        thddot = L*(u2 - u1)/I  # thetaddot
        state_dot = np.array([xdot, ydot, thdot, xddot, yddot, thddot])
        new_state = state + state_dot*dt  # explicit Euler step
        return new_state

    self._dynamics = _dynamics
    self._dynamics_real = _dynamics_real
    self._dynamics_der = jax.jit(jax.jacfwd(_dynamics))
    self._dynamics_real_der = jax.jit(jax.jacfwd(_dynamics_real))

    def _costval(x, u, i):
        if i == self.h:
            # Terminal cost: weighted squared distance to the goal; no control term.
            return 100*np.linalg.norm(x - self.goal_state)**2
        # Stage cost: squared distance to the goal plus a penalty on
        # deviation from the hover thrust.
        return (np.linalg.norm(x - self.goal_state)**2
                + 0.1*np.linalg.norm(u - self.hover_input)**2)

    def _costgrad(x, u, i):
        # Returns [l_x, l_u, l_xx, l_ux, l_uu], the derivatives of _costval.
        if i == self.h:
            return [200*(x - self.goal_state),
                    # The terminal cost does not depend on u, so l_u is a zero
                    # *vector* (same shape as the stage-cost l_u), not a matrix.
                    np.zeros(self.action_size),
                    200*np.eye(self.state_size),
                    np.zeros((self.action_size, self.state_size)),
                    np.zeros((self.action_size, self.action_size))]
        return [2*(x - self.goal_state),
                0.2*(u - self.hover_input),
                2*np.eye(self.state_size),
                np.zeros((self.action_size, self.state_size)),
                0.2*np.eye(self.action_size)]

    self._cost = _costval
    self._costgrad = _costgrad
phi0, pt0 = mom.phi0pt0(pos, q, B) # The flight length in the transverse plane, measured from the point of # the helix closeset to the z-axis length = (phi - phi0) * pt / qalph return ( Helix( d0=(pt0 - pt) / qalph, phi0=phi0, omega=qalph / pt, z0=pos.z - length * mom.pz / pt, tanl=mom.pz/pt), length) position_from_helix_jacobian = jax.vmap( jax.jacfwd(position_from_helix, argnums=0)) momentum_from_helix_jacobian = jax.vmap( jax.jacfwd(momentum_from_helix, argnums=0)) def full_jacobian_from_helix(hel: Helix, length: dtype, q: int, B: float)\ -> (dtype): """ Calculates helix over (pos, mom) jacobian. Returns np.array of shape (N, 5, 6), where N is number of events """ jac_pos = position_from_helix_jacobian(hel, length, q, B) jac_mom = momentum_from_helix_jacobian(hel, length, q, B) return np.stack([ jac_pos.x.as_array,
import jax.numpy as jnp


def _pullback(fun, primal, cotangent):
    """Apply the VJP of single-argument ``fun`` at ``primal`` to ``cotangent``."""
    _, vjp_fun = vjp(fun, primal)
    (result,) = vjp_fun(cotangent)
    return result


@jit
def incremental_arclength_pure(d1gamma):
    """Return the Euclidean norm of each row of ``d1gamma`` (|gamma'| per point)."""
    return jnp.linalg.norm(d1gamma, axis=1)


@jit
def incremental_arclength_vjp(d1gamma, v):
    """VJP of :func:`incremental_arclength_pure` with respect to ``d1gamma``."""
    return _pullback(incremental_arclength_pure, d1gamma, v)


@jit
def kappa_pure(d1gamma, d2gamma):
    """Curvature per row: ``|d1gamma x d2gamma| / |d1gamma|**3``."""
    numerator = jnp.linalg.norm(jnp.cross(d1gamma, d2gamma), axis=1)
    return numerator / jnp.linalg.norm(d1gamma, axis=1)**3


@jit
def kappavjp0(d1gamma, d2gamma, v):
    """VJP of :func:`kappa_pure` with respect to ``d1gamma``."""
    return _pullback(lambda d1g: kappa_pure(d1g, d2gamma), d1gamma, v)


@jit
def kappavjp1(d1gamma, d2gamma, v):
    """VJP of :func:`kappa_pure` with respect to ``d2gamma``."""
    return _pullback(lambda d2g: kappa_pure(d1gamma, d2g), d2gamma, v)


@jit
def kappagrad0(d1gamma, d2gamma):
    """Forward-mode Jacobian of :func:`kappa_pure` with respect to ``d1gamma``."""
    return jacfwd(kappa_pure, argnums=0)(d1gamma, d2gamma)


@jit
def kappagrad1(d1gamma, d2gamma):
    """Forward-mode Jacobian of :func:`kappa_pure` with respect to ``d2gamma``."""
    return jacfwd(kappa_pure, argnums=1)(d1gamma, d2gamma)


@jit
def torsion_pure(d1gamma, d2gamma, d3gamma):
    """Torsion per row: ``(d1 x d2) . d3 / |d1 x d2|**2``."""
    d1_cross_d2 = jnp.cross(d1gamma, d2gamma, axis=1)
    return (jnp.sum(d1_cross_d2 * d3gamma, axis=1)
            / jnp.sum(d1_cross_d2**2, axis=1))


@jit
def torsionvjp0(d1gamma, d2gamma, d3gamma, v):
    """VJP of :func:`torsion_pure` with respect to ``d1gamma``."""
    return _pullback(lambda d1g: torsion_pure(d1g, d2gamma, d3gamma), d1gamma, v)


@jit
def torsionvjp1(d1gamma, d2gamma, d3gamma, v):
    """VJP of :func:`torsion_pure` with respect to ``d2gamma``."""
    return _pullback(lambda d2g: torsion_pure(d1gamma, d2g, d3gamma), d2gamma, v)


@jit
def torsionvjp2(d1gamma, d2gamma, d3gamma, v):
    """VJP of :func:`torsion_pure` with respect to ``d3gamma``."""
    return _pullback(lambda d3g: torsion_pure(d1gamma, d2gamma, d3g), d3gamma, v)


class Curve:
    """Holds a ``dependencies`` list, which is empty on construction."""

    def __init__(self):
        self.dependencies = []
def single_jac(x_new):
    """Return the forward-mode Jacobian of ``fun`` w.r.t. its first argument.

    ``fun`` and ``fun_kwargs`` are not parameters of this function; they are
    resolved from the enclosing scope. ``fun_kwargs`` is forwarded unchanged
    as keyword arguments on every call.

    Args:
        x_new: Point at which the Jacobian is evaluated.

    Returns:
        The Jacobian array produced by ``jax.jacfwd``.
    """
    # No stdout side effects: this may be called many times inside a solver loop.
    return jax.jacfwd(fun, argnums=0)(x_new, **fun_kwargs)
def threebody_jax(tmax=17.0652165601579625588917206249):
    r"""Initial value problem (IVP) based on a three-body problem.

    Let the initial conditions be :math:`y = (y_1, y_2, \dot{y}_1, \dot{y}_2)^T`.
    This function implements the second-order three-body problem as a system of
    first-order ODEs, which is defined as follows: [1]_

    .. math::

        f(t, y) =
        \begin{pmatrix}
            \dot{y_1} \\
            \dot{y_2} \\
            y_1 + 2 \dot{y}_2 - \frac{(1 - \mu) (y_1 + \mu)}{d_1}
                - \frac{\mu (y_1 - (1 - \mu))}{d_2} \\
            y_2 - 2 \dot{y}_1 - \frac{(1 - \mu) y_2}{d_1} - \frac{\mu y_2}{d_2}
        \end{pmatrix}

    with

    .. math::

        d_1 &= ((y_1 + \mu)^2 + y_2^2)^{\frac{3}{2}} \\
        d_2 &= ((y_1 - (1 - \mu))^2 + y_2^2)^{\frac{3}{2}}

    and a constant parameter  :math:`\mu = 0.012277471` denoting the standardized moon mass.

    The Jacobian and Hessian of the right-hand side are obtained by automatic
    differentiation with JAX. Calling this function enables 64-bit floats in
    JAX globally (``jax_enable_x64``).

    Parameters
    ----------
    tmax
        Final time.

    Returns
    -------
    IVP
        IVP object describing a three-body problem IVP with the prescribed
        configuration.

    Raises
    ------
    ImportError
        If JAX is not installed.

    References
    ----------
    .. [1] Hairer, E., Norsett, S. and Wanner, G..
           Solving Ordinary Differential Equations I.
           Springer Series in Computational Mathematics, 1993.
    """
    try:
        import jax
        import jax.numpy as jnp
    except ImportError as err:
        raise ImportError(JAX_ERRORMSG) from err

    # Enable 64-bit floats. This is a global JAX setting. ``jax.config.update``
    # is used rather than ``from jax.config import config``: that import fails
    # on recent JAX releases. Inside the try-block above, it would also have
    # been misreported as "JAX is not installed".
    jax.config.update("jax_enable_x64", True)

    def threebody_rhs(Y):
        # defining the ODE:
        # assume Y = [y1,y2,y1',y2']
        mu = 0.012277471  # a constant (standardised moon mass)
        mp = 1 - mu
        D1 = ((Y[0] + mu)**2 + Y[1]**2)**(3 / 2)
        D2 = ((Y[0] - mp)**2 + Y[1]**2)**(3 / 2)
        y1p = Y[0] + 2 * Y[3] - mp * (Y[0] + mu) / D1 - mu * (Y[0] - mp) / D2
        y2p = Y[1] - 2 * Y[2] - mp * Y[1] / D1 - mu * Y[1] / D2
        return jnp.array([Y[2], Y[3], y1p, y2p])

    # Jacobian (forward mode) and second derivative (reverse over forward)
    # of the right-hand side with respect to the state.
    df = jax.jit(jax.jacfwd(threebody_rhs))
    ddf = jax.jit(jax.jacrev(df))

    # The system is autonomous: t is accepted for the IVP interface but unused.
    def rhs(t, y):
        return threebody_rhs(Y=y)

    def jac(t, y):
        return df(y)

    def hess(t, y):
        return ddf(y)

    y0 = np.array([0.994, 0, 0, -2.00158510637908252240537862224])
    t0 = 0.0
    return InitialValueProblem(f=rhs, t0=t0, tmax=tmax, y0=y0, df=jac, ddf=hess)
def _load_pickle(path):
    """Unpickle and return the object stored at ``path``, closing the file."""
    # NOTE(review): pickle can execute arbitrary code; only load trusted local files.
    with open(path, 'rb') as fh:
        return pickle.load(fh)


# load hmc warmup
inv_metric = _load_pickle('stan_traces/inv_metric.pkl')
stepsize = _load_pickle('stan_traces/step_size.pkl')
last_pos = _load_pickle('stan_traces/last_pos.pkl')

# define MPC cost, gradient and hessian function
# Arguments 11-15 are static, so changing any of them (e.g. N) triggers a recompile.
cost = jit(log_barrier_cosine_cost, static_argnums=(11, 12, 13, 14, 15))
# get compiled function to return gradients with respect to z (uc, s)
gradient = jit(
    grad(log_barrier_cosine_cost, argnums=0),
    static_argnums=(11, 12, 13, 14, 15),
)
# Hessian with respect to z: forward-mode over the reverse-mode gradient.
hessian = jit(jacfwd(jacrev(log_barrier_cosine_cost, argnums=0)),
              static_argnums=(11, 12, 13, 14, 15))

# Optimiser settings (presumably consumed by the MPC solve later — TODO confirm).
mu = 1e4
gamma = 1
delta = 0.05
max_iter = 5000

# declare some variables for storing the ongoing results
xt_est_save = np.zeros((Ns, Nx, T))
theta_est_save = np.zeros((Ns, 6, T))
q_est_save = np.zeros((Ns, 4, T))
r_est_save = np.zeros((Ns, 3, T))
uc_save = np.zeros((1, Nh, T))
mpc_result_save = []
hmc_traces_save = []