def backward_pass(self, F_x, F_u, L_x, L_u, L_xx, L_ux, L_uu, F_xx=None, F_ux=None, F_uu=None):
    """Computes the feedforward and feedback gains k and K.

    Runs the backward recursion from the terminal cost derivatives
    (``L_x[-1]``, ``L_xx[-1]``) down to time step 0, calling ``self.Q`` at
    each step and propagating the value-function derivatives ``V_x``/``V_xx``.

    Args:
        F_x: Jacobian of state path w.r.t. x [N, state_size, state_size].
        F_u: Jacobian of state path w.r.t. u [N, state_size, action_size].
        L_x: Jacobian of cost path w.r.t. x [N+1, state_size].
        L_u: Jacobian of cost path w.r.t. u [N, action_size].
        L_xx: Hessian of cost path w.r.t. x, x [N+1, state_size, state_size].
        L_ux: Hessian of cost path w.r.t. u, x [N, action_size, state_size].
        L_uu: Hessian of cost path w.r.t. u, u [N, action_size, action_size].
        F_xx: Hessian of state path w.r.t. x, x if Hessians are used
            [N, state_size, state_size, state_size].
        F_ux: Hessian of state path w.r.t. u, x if Hessians are used
            [N, state_size, action_size, state_size].
        F_uu: Hessian of state path w.r.t. u, u if Hessians are used
            [N, state_size, action_size, action_size].
            NOTE: F_xx, F_ux and F_uu are accepted for interface compatibility
            but are currently ignored (the second-order dynamics path is
            disabled; ``self.Q`` is called without them).

    Returns:
        Tuple of
            k: feedforward gains [N, action_size].
            K: feedback gains [N, action_size, state_size].
    """
    V_x = L_x[-1]
    V_xx = L_xx[-1]

    # Gain buffers, filled from the last time step backwards. JAX arrays are
    # immutable, so rows are written with the functional `.at[i].set(...)`
    # (the replacement for the removed `jax.ops.index_update`).
    k = np.empty_like(self._k)
    K = np.empty_like(self._K)

    for i in range(self.N - 1, -1, -1):
        Q_x, Q_u, Q_xx, Q_ux, Q_uu = self.Q(F_x[i], F_u[i], L_x[i], L_u[i],
                                            L_xx[i], L_ux[i], L_uu[i],
                                            V_x, V_xx)

        # Eq (6).
        k = k.at[i].set(-np.linalg.solve(Q_uu, Q_u))
        K = K.at[i].set(-np.linalg.solve(Q_uu, Q_ux))
        # Read the stored rows back so any dtype cast by the buffer is applied,
        # exactly as the value-function update below expects.
        k_i = k[i]
        K_i = K[i]

        # Eq (11b).
        V_x = Q_x + K_i.T.dot(Q_uu).dot(k_i)
        V_x += K_i.T.dot(Q_u) + Q_ux.T.dot(k_i)

        # Eq (11c).
        V_xx = Q_xx + K_i.T.dot(Q_uu).dot(K_i)
        V_xx += K_i.T.dot(Q_ux) + Q_ux.T.dot(K_i)
        V_xx = 0.5 * (V_xx + V_xx.T)  # To maintain symmetry.

    return np.array(k), np.array(K)
def forward_pass(x_trj, u_trj, k_trj, K_trj):
    """Roll out a new state/control trajectory using the given gains.

    The per-step update is performed by ``forward_pass_looper`` inside
    ``lax.fori_loop`` for ``TIME_STEPS - 1`` iterations. The new state
    trajectory starts from the same initial state as ``x_trj``.

    Args:
        x_trj: nominal state trajectory; ``x_trj[0]`` is the rollout's
            initial state.
        u_trj: nominal control trajectory. It is passed through
            ``arcsin(sin(.))`` first, which folds every entry into
            ``[-pi/2, pi/2]``.
        k_trj: gain trajectory consumed by ``forward_pass_looper``.
        K_trj: gain trajectory consumed by ``forward_pass_looper``.

    Returns:
        Tuple ``(x_trj_new, u_trj_new)`` produced by the loop.
    """
    # NOTE(review): folding controls into [-pi/2, pi/2] looks like a deliberate
    # bound on the control input — confirm.
    u_trj = np.arcsin(np.sin(u_trj))

    # `.at[0].set(...)` replaces the removed `jax.ops.index_update`.
    x_trj_new = np.empty_like(x_trj).at[0].set(x_trj[0])
    u_trj_new = np.empty_like(u_trj)

    # The carry stays a list so its structure matches what the looper returns.
    *_, x_trj_new, u_trj_new = lax.fori_loop(
        0, TIME_STEPS - 1, forward_pass_looper,
        [x_trj, u_trj, k_trj, K_trj, x_trj_new, u_trj_new]
    )
    return x_trj_new, u_trj_new
def backward_pass(x_trj, u_trj, regu, target):
    """Run the backward recursion over the horizon and collect the gains.

    Seeds the value-function derivatives from ``derivative_final`` at the
    last state, then lets ``backward_pass_looper`` update the loop carry for
    ``TIME_STEPS - 1`` iterations of ``lax.fori_loop``.

    Args:
        x_trj: state trajectory (its last entry seeds the recursion).
        u_trj: control trajectory; also fixes the shape of ``k_trj``.
        regu: regularisation value carried through the loop.
        target: target passed to ``derivative_final`` and carried through
            the loop.

    Returns:
        Tuple ``(k_trj, K_trj, expected_cost_redu)`` taken from the final
        loop carry; ``K_trj`` has shape ``(TIME_STEPS - 1, N_U, N_X)``.
    """
    terminal_V_x, terminal_V_xx = derivative_final(x_trj[-1], target)

    # Loop carry (kept as a list so its pytree structure matches the looper).
    carry = [
        terminal_V_x,
        terminal_V_xx,
        np.empty_like(u_trj),                  # k_trj buffer
        np.empty((TIME_STEPS-1, N_U, N_X)),    # K_trj buffer
        x_trj,
        u_trj,
        0.,                                    # expected_cost_redu accumulator
        regu,
        target,
    ]
    carry = lax.fori_loop(0, TIME_STEPS-1, backward_pass_looper, carry)

    _, _, gains_ff, gains_fb, _, _, cost_reduction, _, _ = carry
    return gains_ff, gains_fb, cost_reduction
def shift(tsr, fill=np.nan):
    """Rolls tensor backwards by one.

    Shifts one-dimensional tensor to the left, discarding the first element
    and filling the empty slot at the end with a new value.

    Args:
        tsr: One-dimensional tensor.
        fill: Value to add to tensor.

    Returns:
        Tensor with same shape as `tsr` but whose elements are shifted to the
        left by one with a new element at the end.
    """
    # `x.at[...].set(...)` is the functional replacement for the removed
    # `jax.ops.index_update`; it returns a new array and leaves `tsr` untouched.
    out = np.empty_like(tsr).at[-1].set(fill)
    return out.at[:-1].set(tsr[1:])
def _cholesky_solve(chol, rhs, name=None):
    """Solve ``A @ x = rhs`` given the lower Cholesky factor ``chol`` of ``A``.

    Scipy cho_solve does not broadcast, so we must do so explicitly: the batch
    dimensions of ``chol`` and ``rhs`` are broadcast against each other and,
    outside JAX, each matrix in the batch is solved with its own call.

    Args:
        chol: lower-triangular Cholesky factor(s), shape ``[..., n, n]``.
        rhs: right-hand side(s), shape ``[..., n, k]``.
        name: ignored.

    Returns:
        The solution, with the broadcast batch shape and trailing ``[n, k]``.

    Raises:
        ValueError: (non-JAX path) if the shapes of ``chol`` and ``rhs`` do
            not broadcast.
    """
    del name  # Unused; accepted for signature compatibility.

    if JAX_MODE:
        # But JAX uses XLA, which can do a batched solve: just give both
        # operands a common batch shape by adding zeros shaped like the
        # other operand's batch dimensions.
        chol_pad = np.zeros(rhs.shape[:-2] + (1, 1), dtype=chol.dtype)
        chol = chol + chol_pad
        rhs_pad = np.zeros(chol.shape[:-2] + (1, 1), dtype=rhs.dtype)
        rhs = rhs + rhs_pad
        return scipy_linalg.cho_solve((chol, True), rhs)

    # Broadcasting one column of `chol` against `rhs` yields
    # (batch..., n, k) without mixing up the two trailing matrix dimensions.
    try:
        joint = onp.broadcast(chol[..., :1], rhs)
    except ValueError as err:
        raise ValueError(
            'Error with inputs shaped `chol`={}, rhs={}:\n{}'.format(
                chol.shape, rhs.shape, str(err)))

    n = chol.shape[-1]
    # joint.shape[:-1] is (batch..., n); append n to get (batch..., n, n).
    chol = onp.broadcast_to(chol, joint.shape[:-1] + (n, ))
    rhs = onp.broadcast_to(rhs, joint.shape)

    batch_count = int(np.prod(chol.shape[:-2]))
    chol_stack = chol.reshape(batch_count, n, n)
    rhs_stack = rhs.reshape(batch_count, n, rhs.shape[-1])

    solution = np.empty_like(rhs_stack)
    if np.size(solution):
        for j, (factor, b) in enumerate(zip(chol_stack, rhs_stack)):
            solution[j] = scipy_linalg.cho_solve((factor, True), b)
    return solution.reshape(*rhs.shape)
def none(val):
    """Return a NaN placeholder matching ``val``.

    A ``jnp.ndarray`` input yields an array of the same shape whose entries
    are all NaN; any other input yields the Python float ``nan``.
    """
    if not isinstance(val, jnp.ndarray):
        return float('nan')
    # NaN times any value is NaN, so every entry of the product is NaN.
    return jnp.nan * jnp.empty_like(val)
def runLaneEmden(N, m, basis, k, xf):
    """Solve the Lane-Emden equation with a univariate TFC constrained expression.

    Solves ``x*y'' + 2*y' + x*y**k = 0`` on ``[0, xf]`` with ``y(0) = 1`` and
    ``y'(0) = 0``. The residual is minimised at the ``tfc.x`` points by
    (non)linear least squares.

    Args:
        N: number of discretization points.
        m: number of basis function terms (cast to ``int``).
        basis: basis function type, passed through to ``utfc``.
        k: specific problem type (the exponent on ``y``), ``k >= 0``.
            Analytical solutions are known for k = 0, 1 and 5.
        xf: right end of the problem domain.

    Returns:
        Tuple ``(x, y, err, res)``:
            x: the discretization points.
            y: the solution evaluated at those points.
            err: the absolute error against the analytical solution. It is
                all NaN when ``k`` is not 0, 1 or 5, since no reference is
                known.
            res: the absolute residual of the loss vector.
    """
    ## problem initial conditions: *****************************************
    xspan = [0., xf]  # problem domain range [x0, xf], with x0 = 0
    y0 = 1.   # y(x0) = 1
    y0p = 0.  # y'(x0) = 0
    nC = 2    # number of constraints

    ## construct univariate tfc class: *************************************
    tfc = utfc(N, nC, int(m), basis=basis, x0=xspan[0], xf=xspan[1])
    x = tfc.x

    H = tfc.H
    dH = tfc.dH
    # Basis and its derivative at the first point (assumed to be x0 — the
    # constraints below rely on this).
    H0 = H(x[0:1])
    H0p = dH(x[0:1])

    ## define tfc constrained expression and derivatives: ******************
    # switching functions
    def phi1(x):
        return np.ones_like(x)

    def phi2(x):
        return x

    # Constrained expression: with x0 = 0 it satisfies y(x0) = y0 and
    # y'(x0) = y0p for any xi (phi1' = 0, phi2(0) = 0, phi2' = 1).
    def y(x, xi):
        return (np.dot(H(x), xi)
                + phi1(x) * (y0 - np.dot(H0, xi))
                + phi2(x) * (y0p - np.dot(H0p, xi)))

    yp = egrad(y)
    ypp = egrad(yp)

    ## define the loss function: *******************************************
    # Lane-Emden equation multiplied through by x to avoid dividing by x = 0.
    def L(xi):
        return x * ypp(x, xi) + 2. * yp(x, xi) + x * y(x, xi)**k

    ## solve the problem via nonlinear least-squares ***********************
    xi = np.zeros(H(x).shape[1])

    # if k==0 or k==1, the problem is linear. Timing outputs are not used.
    if k == 0 or k == 1:
        xi, _ = LS(xi, L, timer=True)
    else:
        xi, _, _ = NLLS(xi, L, timer=True)

    ## compute the error (if k = 0, 1, or 5): ******************************
    if k == 0:
        ytrue = 1. - 1. / 6. * x**2
    elif k == 1:
        # sin(x)/x, with the removable singularity at x[0] (assumed 0) set to
        # its limit 1.
        ytrue = onp.ones_like(x)
        ytrue[1:] = np.sin(x[1:]) / x[1:]
    elif k == 5:
        ytrue = (1. + x**2 / 3)**(-1 / 2)
    else:
        # No analytical solution: report NaN instead of comparing against
        # uninitialised/zero values, which would masquerade as a real error.
        ytrue = np.full_like(x, np.nan)

    y_sol = y(x, xi)
    err = np.abs(y_sol - ytrue)

    ## compute the residual of the loss vector: ****************************
    res = np.abs(L(xi))

    return x, y_sol, err, res
def f(x):
    """Allocate an array like ``x`` and try to fill it in place with ``x``.

    NOTE(review): JAX arrays are immutable and expose no ``fill`` method, so
    with ``jnp`` bound to ``jax.numpy`` the ``fill`` call raises
    ``AttributeError``. This may be a deliberate error-path fixture; if a
    filled copy is actually wanted, ``jnp.full_like`` is the JAX idiom —
    confirm intent.
    """
    buffer = jnp.empty_like(x)
    buffer.fill(x)  # in-place mutation; see the note in the docstring
    return buffer
def shallow_water_step(state, is_first_step):
    """Perform one step of the shallow-water model.

    Args:
        state: model state, unpacked as ``(h, u, v, dh, du, dv)``. The ``d*``
            fields are the tendencies from the previous step.
        is_first_step: if true, fields are advanced with a forward-Euler step
            (``dt * tendency``). Otherwise they use the two-level
            Adams-Bashforth combination of the new and previous tendencies.

    Returns:
        ``ModelState(h, u, v, dh_new, du_new, dv_new)``: the advanced fields
        plus the tendencies computed in this step.
    """
    token = jax.lax.create_token()
    h, u, v, dh, du, dv = state

    # Edge-padded copy of the interior of h, with the halo refreshed.
    hc = jnp.pad(h[1:-1, 1:-1], 1, "edge")
    hc, token = enforce_boundaries(hc, "h", token)

    # Continuity equation: mass fluxes along axis 1 (fe, divided by dx) and
    # axis 0 (fn, divided by dy), using h averaged onto the velocity points.
    fe = jnp.empty_like(u)
    fn = jnp.empty_like(u)

    fe = fe.at[1:-1, 1:-1].set(0.5 * (hc[1:-1, 1:-1] + hc[1:-1, 2:]) * u[1:-1, 1:-1])
    fn = fn.at[1:-1, 1:-1].set(0.5 * (hc[1:-1, 1:-1] + hc[2:, 1:-1]) * v[1:-1, 1:-1])
    fe, token = enforce_boundaries(fe, "u", token)
    fn, token = enforce_boundaries(fn, "v", token)

    dh_new = dh.at[1:-1, 1:-1].set(
        -(fe[1:-1, 1:-1] - fe[1:-1, :-2]) / dx
        - (fn[1:-1, 1:-1] - fn[:-2, 1:-1]) / dy
    )

    # nonlinear momentum equation
    q = jnp.empty_like(u)
    ke = jnp.empty_like(u)

    # planetary and relative vorticity
    q = q.at[1:-1, 1:-1].set(
        CORIOLIS_PARAM[1:-1, 1:-1]
        + (
            (v[1:-1, 2:] - v[1:-1, 1:-1]) / dx
            - (u[2:, 1:-1] - u[1:-1, 1:-1]) / dy
        )
    )

    # potential vorticity: divide by h averaged over the four surrounding points
    q = q.at[1:-1, 1:-1].mul(
        1.0 / (0.25 * (hc[1:-1, 1:-1] + hc[1:-1, 2:] + hc[2:, 1:-1] + hc[2:, 2:]))
    )
    q, token = enforce_boundaries(q, "h", token)

    # pressure-gradient and vorticity-flux terms
    du_new = du.at[1:-1, 1:-1].set(
        -GRAVITY * (h[1:-1, 2:] - h[1:-1, 1:-1]) / dx
        + 0.5
        * (
            q[1:-1, 1:-1] * 0.5 * (fn[1:-1, 1:-1] + fn[1:-1, 2:])
            + q[:-2, 1:-1] * 0.5 * (fn[:-2, 1:-1] + fn[:-2, 2:])
        )
    )
    dv_new = dv.at[1:-1, 1:-1].set(
        -GRAVITY * (h[2:, 1:-1] - h[1:-1, 1:-1]) / dy
        - 0.5
        * (
            q[1:-1, 1:-1] * 0.5 * (fe[1:-1, 1:-1] + fe[2:, 1:-1])
            + q[1:-1, :-2] * 0.5 * (fe[1:-1, :-2] + fe[2:, :-2])
        )
    )

    # kinetic energy (per unit mass) at h points, from averaged u**2 and v**2
    ke = ke.at[1:-1, 1:-1].set(
        0.5
        * (
            0.5 * (u[1:-1, 1:-1] ** 2 + u[1:-1, :-2] ** 2)
            + 0.5 * (v[1:-1, 1:-1] ** 2 + v[:-2, 1:-1] ** 2)
        )
    )
    ke, token = enforce_boundaries(ke, "h", token)

    du_new = du_new.at[1:-1, 1:-1].add(-(ke[1:-1, 2:] - ke[1:-1, 1:-1]) / dx)
    dv_new = dv_new.at[1:-1, 1:-1].add(-(ke[2:, 1:-1] - ke[1:-1, 1:-1]) / dy)

    # time integration
    if is_first_step:
        u = u.at[1:-1, 1:-1].add(dt * du_new[1:-1, 1:-1])
        v = v.at[1:-1, 1:-1].add(dt * dv_new[1:-1, 1:-1])
        h = h.at[1:-1, 1:-1].add(dt * dh_new[1:-1, 1:-1])
    else:
        u = u.at[1:-1, 1:-1].add(
            dt
            * (
                ADAMS_BASHFORTH_A * du_new[1:-1, 1:-1]
                + ADAMS_BASHFORTH_B * du[1:-1, 1:-1]
            )
        )
        v = v.at[1:-1, 1:-1].add(
            dt
            * (
                ADAMS_BASHFORTH_A * dv_new[1:-1, 1:-1]
                + ADAMS_BASHFORTH_B * dv[1:-1, 1:-1]
            )
        )
        h = h.at[1:-1, 1:-1].add(
            dt
            * (
                ADAMS_BASHFORTH_A * dh_new[1:-1, 1:-1]
                + ADAMS_BASHFORTH_B * dh[1:-1, 1:-1]
            )
        )

    h, token = enforce_boundaries(h, "h", token)
    u, token = enforce_boundaries(u, "u", token)
    v, token = enforce_boundaries(v, "v", token)

    if LATERAL_VISCOSITY > 0:
        # lateral friction on u: fluxes from the gradients of u
        fe = fe.at[1:-1, 1:-1].set(
            LATERAL_VISCOSITY * (u[1:-1, 2:] - u[1:-1, 1:-1]) / dx
        )
        fn = fn.at[1:-1, 1:-1].set(
            LATERAL_VISCOSITY * (u[2:, 1:-1] - u[1:-1, 1:-1]) / dy
        )
        fe, token = enforce_boundaries(fe, "u", token)
        fn, token = enforce_boundaries(fn, "v", token)

        u = u.at[1:-1, 1:-1].add(
            dt
            * (
                (fe[1:-1, 1:-1] - fe[1:-1, :-2]) / dx
                + (fn[1:-1, 1:-1] - fn[:-2, 1:-1]) / dy
            )
        )

        # lateral friction on v: fluxes from the gradients of v. The previous
        # code subtracted u[1:-1, 1:-1] here (copy-paste bug), mixing the
        # two velocity components.
        fe = fe.at[1:-1, 1:-1].set(
            LATERAL_VISCOSITY * (v[1:-1, 2:] - v[1:-1, 1:-1]) / dx
        )
        fn = fn.at[1:-1, 1:-1].set(
            LATERAL_VISCOSITY * (v[2:, 1:-1] - v[1:-1, 1:-1]) / dy
        )
        fe, token = enforce_boundaries(fe, "u", token)
        fn, token = enforce_boundaries(fn, "v", token)

        v = v.at[1:-1, 1:-1].add(
            dt
            * (
                (fe[1:-1, 1:-1] - fe[1:-1, :-2]) / dx
                + (fn[1:-1, 1:-1] - fn[:-2, 1:-1]) / dy
            )
        )

    return ModelState(h, u, v, dh_new, du_new, dv_new)
def enforce_boundaries(arr, grid, token=None):
    """Handle boundary exchange between processors.

    This is where mpi4jax comes in!

    Fills the one-cell halo of ``arr`` with the adjacent interior row/column
    of each neighbouring process, using ``mpi4jax`` send/recv/sendrecv. It
    then zeroes the last interior column (``u``) or row (``v``) on the
    outermost processes. NOTE(review): this presumably imposes closed
    (no-normal-flow) walls — confirm.

    Args:
        arr: local field with a one-cell halo. Axis 0 runs south -> north and
            axis 1 runs west -> east (see the overlap slices below).
        grid: staggered grid the field lives on; one of "h", "u", "v".
        token: optional JAX token used to order the MPI calls; a new one is
            created when omitted.

    Returns:
        Tuple ``(arr, token)``: the halo-updated array and the token threaded
        through all communication calls.
    """
    assert grid in ("h", "u", "v")

    # start sending west, go clockwise
    send_order = (
        "west",
        "north",
        "east",
        "south",
    )

    # start receiving east, go clockwise
    recv_order = (
        "east",
        "south",
        "west",
        "north",
    )
    # send_order[i] and recv_order[i] are opposite directions, so in each
    # round every process sends one way while receiving from the other.

    # Interior row/column next to each edge: the data we send.
    overlap_slices_send = dict(
        south=(1, slice(None), Ellipsis),
        west=(slice(None), 1, Ellipsis),
        north=(-2, slice(None), Ellipsis),
        east=(slice(None), -2, Ellipsis),
    )

    # Halo row/column on each edge: where received data is written.
    overlap_slices_recv = dict(
        south=(0, slice(None), Ellipsis),
        west=(slice(None), 0, Ellipsis),
        north=(-1, slice(None), Ellipsis),
        east=(slice(None), -1, Ellipsis),
    )

    # (y, x) process-grid coordinates of each neighbour, or None at a domain edge.
    proc_neighbors = {
        "south": (proc_idx[0] - 1, proc_idx[1]) if proc_idx[0] > 0 else None,
        "west": (proc_idx[0], proc_idx[1] - 1) if proc_idx[1] > 0 else None,
        "north": (proc_idx[0] + 1, proc_idx[1]) if proc_idx[0] < nproc_y - 1 else None,
        "east": (proc_idx[0], proc_idx[1] + 1) if proc_idx[1] < nproc_x - 1 else None,
    }

    # Periodic in x: the westernmost and easternmost processes are neighbours.
    if PERIODIC_BOUNDARY_X:
        if proc_idx[1] == 0:
            proc_neighbors["west"] = (proc_idx[0], nproc_x - 1)

        if proc_idx[1] == nproc_x - 1:
            proc_neighbors["east"] = (proc_idx[0], 0)

    if token is None:
        token = jax.lax.create_token()

    for send_dir, recv_dir in zip(send_order, recv_order):
        send_proc = proc_neighbors[send_dir]
        recv_proc = proc_neighbors[recv_dir]

        # No neighbour on either side of this axis: nothing to exchange.
        if send_proc is None and recv_proc is None:
            continue

        # Convert (y, x) coordinates to a flat rank (row-major over
        # (nproc_y, nproc_x)).
        if send_proc is not None:
            send_proc = np.ravel_multi_index(send_proc, (nproc_y, nproc_x))

        if recv_proc is not None:
            recv_proc = np.ravel_multi_index(recv_proc, (nproc_y, nproc_x))

        recv_idx = overlap_slices_recv[recv_dir]
        # Receive buffer with the shape/dtype of the halo slice.
        recv_arr = jnp.empty_like(arr[recv_idx])

        send_idx = overlap_slices_send[send_dir]
        send_arr = arr[send_idx]

        if send_proc is None:
            # Only a neighbour to receive from.
            recv_arr, token = mpi4jax.recv(
                recv_arr, source=recv_proc, comm=mpi_comm, token=token
            )
            arr = arr.at[recv_idx].set(recv_arr)
        elif recv_proc is None:
            # Only a neighbour to send to.
            token = mpi4jax.send(send_arr, dest=send_proc, comm=mpi_comm, token=token)
        else:
            # Neighbours on both sides: combined send + receive in one call.
            recv_arr, token = mpi4jax.sendrecv(
                send_arr,
                recv_arr,
                source=recv_proc,
                dest=send_proc,
                comm=mpi_comm,
                token=token,
            )
            arr = arr.at[recv_idx].set(recv_arr)

    # Easternmost process, non-periodic x: zero the last interior u column.
    if not PERIODIC_BOUNDARY_X and grid == "u" and proc_idx[1] == nproc_x - 1:
        arr = arr.at[:, -2].set(0.0)

    # Northernmost process: zero the last interior v row.
    if grid == "v" and proc_idx[0] == nproc_y - 1:
        arr = arr.at[-2, :].set(0.0)

    return arr, token
def _bdf_step(state, fun, jac):
    """Attempt one BDF step, retrying until the step is accepted, then advance.

    Inside a ``jax.lax.while_loop``, each iteration:
      * runs ``_newton_iteration(state, fun)``;
      * on non-convergence, either refreshes the Jacobian via
        ``_update_jacobian(state, jac)`` (first failure) or shrinks the step
        by 0.3 (if the Jacobian was already refreshed);
      * on convergence with ``error_norm > 1``, rescales the step by
        ``factor``.
    The loop exits once ``step_accepted`` becomes true, i.e. Newton converged
    and ``error_norm <= 1``. The accepted step then advances ``t`` and the
    step counters, and the state is prepared for the next step, with an
    order change only after ``order + 1`` equal steps.

    Args:
        state: solver state (a namedtuple-like object; ``_replace`` is used
            on it).
        fun: right-hand side, forwarded to ``_newton_iteration``.
        jac: Jacobian function, forwarded to ``_update_jacobian``.

    Returns:
        The updated solver state after the accepted step.
    """
    # we will try and use the old jacobian unless convergence of newton iteration
    # fails
    updated_jacobian = False
    # initialise step size and try to make the step,
    # iterate, reducing step size until error is in bounds
    step_accepted = False
    # Placeholders giving the loop carry the right shape/dtype; the first
    # Newton iteration overwrites y and d before they are used.
    y = jnp.empty_like(state.y0)
    d = jnp.empty_like(state.y0)
    n_iter = -1

    # loop until step is accepted
    while_state = [state, step_accepted, updated_jacobian, y, d, n_iter]

    def while_cond(while_state):
        _, step_accepted, _, _, _, _ = while_state
        return step_accepted == False  # noqa: E712

    def while_body(while_state):
        state, step_accepted, updated_jacobian, y, d, n_iter = while_state

        # solve BDF equation using y0 as starting point
        converged, n_iter, y, d, state = _newton_iteration(state, fun)
        not_converged = converged == False  # noqa: E712

        # newton iteration did not converge, but jacobian has already been
        # evaluated so reduce step size by 0.3 (as per [1]) and try again
        state = tree_multimap(
            partial(jnp.where, not_converged * updated_jacobian),
            _update_step_size_and_lu(state, 0.3),
            state,
        )

        # if not converged and jacobian not updated, then update the jacobian and try
        # again
        # NOTE(review): `False + updated_jacobian` equals `updated_jacobian`
        # numerically; it presumably forces a numeric value for jnp.where —
        # confirm.
        (state, updated_jacobian) = tree_multimap(
            partial(
                jnp.where, not_converged * (updated_jacobian == False)  # noqa: E712
            ),
            (_update_jacobian(state, jac), True),
            (state, False + updated_jacobian),
        )

        # Safety factor shrinks as the number of Newton iterations grows.
        safety = 0.9 * (2 * NEWTON_MAXITER + 1) / (2 * NEWTON_MAXITER + n_iter)
        scale_y = state.atol + state.rtol * jnp.abs(y)

        # combine eq 3, 4 and 6 from [1] to obtain error
        # Note that error = C_k * h^{k+1} y^{k+1}
        # and d = D^{k+1} y_{n+1} \approx h^{k+1} y^{k+1}
        error = state.error_const[state.order] * d
        error_norm = rms_norm(error / scale_y)

        # calculate optimal step size factor as per eq 2.46 of [2]
        factor = jnp.maximum(
            MIN_FACTOR, safety * error_norm ** (-1 / (state.order + 1))
        )

        # Converged but error too large: rescale the step and reject it.
        # Otherwise keep the state; the step counts as accepted iff converged.
        (state, step_accepted) = tree_multimap(
            partial(jnp.where, converged * (error_norm > 1)),  # noqa: E712
            (_update_step_size_and_lu(state, factor), False),
            (state, converged),
        )

        return [state, step_accepted, updated_jacobian, y, d, n_iter]

    state, step_accepted, updated_jacobian, y, d, n_iter = jax.lax.while_loop(
        while_cond, while_body, while_state
    )

    # take the accepted step
    n_steps = state.n_steps + 1
    t = state.t + state.h

    # a change in order is only done after running at order k for k + 1 steps
    # (see page 83 of [2])
    n_equal_steps = state.n_equal_steps + 1

    state = state._replace(n_equal_steps=n_equal_steps, t=t, n_steps=n_steps)

    # Keep the current order until order + 1 equal steps have been taken.
    state = tree_multimap(
        partial(jnp.where, n_equal_steps < state.order + 1),
        _prepare_next_step(state, d),
        _prepare_next_step_order_change(state, d, y, n_iter),
    )

    return state