Example #1
0
	def backward_pass(self, F_x, F_u, L_x, L_u, L_xx, L_ux, L_uu, F_xx=None, F_ux=None, F_uu=None):
		"""Computes the feedforward and feedback gains k and K.

		Runs the backward recursion from the terminal cost derivatives
		(``L_x[-1]``, ``L_xx[-1]``) down to step 0, using ``self.Q`` to build
		the local quadratic model at each step.

		Args:
		    F_x: Jacobian of state path w.r.t. x [N, state_size, state_size].
		    F_u: Jacobian of state path w.r.t. u [N, state_size, action_size].
		    L_x: Jacobian of cost path w.r.t. x [N+1, state_size].
		    L_u: Jacobian of cost path w.r.t. u [N, action_size].
		    L_xx: Hessian of cost path w.r.t. x, x
		        [N+1, state_size, state_size].
		    L_ux: Hessian of cost path w.r.t. u, x [N, action_size, state_size].
		    L_uu: Hessian of cost path w.r.t. u, u
		        [N, action_size, action_size].
		    F_xx, F_ux, F_uu: second-order dynamics terms. They are accepted for
		        interface compatibility but are currently IGNORED: this
		        implementation always calls ``self.Q`` without them.

		Returns:
		    Tuple of
		        k: feedforward gains [N, action_size].
		        K: feedback gains [N, action_size, state_size].
		"""
		V_x = L_x[-1]
		V_xx = L_xx[-1]

		# jax.ops.index_update has been removed from JAX; use the functional
		# ``.at[...].set`` API instead. ``jax.numpy.asarray`` makes sure we hold
		# JAX arrays (which have ``.at``) even if ``np`` is plain NumPy here.
		k = jax.numpy.asarray(np.empty_like(self._k))
		K = jax.numpy.asarray(np.empty_like(self._K))
		for i in range(self.N - 1, -1, -1):
			# NOTE: the second-order (Hessian) variant of Q is not used.
			Q_x, Q_u, Q_xx, Q_ux, Q_uu = self.Q(F_x[i], F_u[i], L_x[i],
			                                    L_u[i], L_xx[i], L_ux[i],
			                                    L_uu[i], V_x, V_xx)

			# Eq (6).
			k = k.at[i].set(-np.linalg.solve(Q_uu, Q_u))
			K = K.at[i].set(-np.linalg.solve(Q_uu, Q_ux))
			# Read back the stored rows so the value update uses exactly the
			# (dtype-cast) gains that are returned.
			k_i = k[i]
			K_i = K[i]

			# Eq (11b).
			V_x = Q_x + K_i.T.dot(Q_uu).dot(k_i)
			V_x += K_i.T.dot(Q_u) + Q_ux.T.dot(k_i)

			# Eq (11c).
			V_xx = Q_xx + K_i.T.dot(Q_uu).dot(K_i)
			V_xx += K_i.T.dot(Q_ux) + Q_ux.T.dot(K_i)
			V_xx = 0.5 * (V_xx + V_xx.T)  # To maintain symmetry.
		return np.array(k), np.array(K)
Example #2
0
def forward_pass(x_trj, u_trj, k_trj, K_trj):
    """Roll out new state/control trajectories using the given gains.

    The per-step work is done by ``forward_pass_looper`` inside a
    ``lax.fori_loop`` over ``TIME_STEPS - 1`` iterations; the new state
    trajectory starts from the same initial state ``x_trj[0]``.

    Args:
        x_trj: current state trajectory (indexed by time along axis 0).
        u_trj: current control trajectory.
        k_trj: feedforward gains.
        K_trj: feedback gains.

    Returns:
        Tuple ``(x_trj_new, u_trj_new)`` as produced by the loop.
    """
    # NOTE(review): arcsin(sin(u)) folds u into [-pi/2, pi/2] by reflection
    # (it is not a 2*pi angle wrap) -- confirm this clamping is intended.
    u_trj = np.arcsin(np.sin(u_trj))

    # jax.ops.index_update / jax.ops.index were removed from JAX; create the
    # buffer as a JAX array and seed the initial state with ``.at[0].set``.
    x_trj_new = jax.numpy.empty_like(x_trj).at[0].set(x_trj[0])
    u_trj_new = np.empty_like(u_trj)

    # The carry is kept as a list: forward_pass_looper receives and returns
    # this structure.
    x_trj, u_trj, k_trj, K_trj, x_trj_new, u_trj_new = lax.fori_loop(
        0, TIME_STEPS - 1, forward_pass_looper,
        [x_trj, u_trj, k_trj, K_trj, x_trj_new, u_trj_new],
    )

    return x_trj_new, u_trj_new
Example #3
0
def backward_pass(x_trj, u_trj, regu, target):
    """Compute the gain trajectories with a ``lax.fori_loop``.

    The recursion is seeded with the terminal derivatives returned by
    ``derivative_final`` and stepped ``TIME_STEPS - 1`` times by
    ``backward_pass_looper``.

    Returns:
        ``(k_trj, K_trj, expected_cost_redu)``: the feedforward gains (shaped
        like ``u_trj``), the feedback gains of shape
        ``(TIME_STEPS - 1, N_U, N_X)`` and the accumulated expected cost
        reduction, all taken from the final loop carry.
    """
    feedforward = np.empty_like(u_trj)
    feedback = np.empty((TIME_STEPS - 1, N_U, N_X))
    terminal_grad, terminal_hess = derivative_final(x_trj[-1], target)

    # Carry layout (kept as a list to match backward_pass_looper):
    # [V_x, V_xx, k_trj, K_trj, x_trj, u_trj, expected_cost_redu, regu, target]
    init_carry = [
        terminal_grad,
        terminal_hess,
        feedforward,
        feedback,
        x_trj,
        u_trj,
        0.,
        regu,
        target,
    ]
    final_carry = lax.fori_loop(0, TIME_STEPS - 1, backward_pass_looper, init_carry)
    (_, _, k_trj, K_trj, _, _, expected_cost_redu, _, _) = final_carry

    return k_trj, K_trj, expected_cost_redu
Example #4
0
def shift(tsr, fill=np.nan):
    """Rolls tensor backwards by one.

    Shifts a one-dimensional tensor to the left, discarding the first
    element and filling the empty slot at the end with a new value.

    Args:
        tsr: One-dimensional tensor.
        fill: Value to place in the last slot (cast to ``tsr``'s dtype).

    Returns:
        JAX array with the same shape as ``tsr`` whose elements are shifted
        to the left by one, with ``fill`` as the last element.
    """
    # jax.ops.index_update / jax.ops.index were removed from JAX; the
    # functional ``.at[...]`` API is the replacement.
    import jax.numpy as jnp

    out = jnp.empty_like(tsr).at[-1].set(fill)
    return out.at[:-1].set(tsr[1:])
Example #5
0
def _cholesky_solve(chol, rhs, name=None):
    """Scipy cho_solve does not broadcast, so we must do so explicitly."""
    del name
    if JAX_MODE:  # But JAX uses XLA, which can do a batched solve.
        chol = chol + np.zeros(rhs.shape[:-2] + (1, 1), dtype=chol.dtype)
        rhs = rhs + np.zeros(chol.shape[:-2] + (1, 1), dtype=rhs.dtype)
        return scipy_linalg.cho_solve((chol, True), rhs)
    try:
        bcast = onp.broadcast(chol[..., :1], rhs)
    except ValueError as e:
        raise ValueError(
            'Error with inputs shaped `chol`={}, rhs={}:\n{}'.format(
                chol.shape, rhs.shape, str(e)))
    dim = chol.shape[-1]
    chol = onp.broadcast_to(chol, bcast.shape[:-1] + (dim, ))
    rhs = onp.broadcast_to(rhs, bcast.shape)
    nbatch = int(np.prod(chol.shape[:-2]))
    flat_chol = chol.reshape(nbatch, dim, dim)
    flat_rhs = rhs.reshape(nbatch, dim, rhs.shape[-1])
    result = np.empty_like(flat_rhs)
    if np.size(result):
        for i, (ch, rh) in enumerate(zip(flat_chol, flat_rhs)):
            result[i] = scipy_linalg.cho_solve((ch, True), rh)
    return result.reshape(*rhs.shape)
Example #6
0
 def none(val):
     """Return a NaN placeholder matching ``val``.

     A JAX array input yields an all-NaN array of the same shape; any other
     input yields the Python float ``nan``.
     """
     if not isinstance(val, jnp.ndarray):
         return float('nan')
     # empty_like only supplies the shape; multiplying by NaN makes every
     # element NaN regardless of what the buffer held.
     return jnp.empty_like(val) * jnp.nan
Example #7
0
def runLaneEmden(N, m, basis, k, xf):
    """Solve x*y'' + 2*y' + x*y**k = 0 on [0, xf] with TFC.

    The constrained expression enforces y(0) = 1 and y'(0) = 0 exactly.
    The free coefficients are found by least squares: linear LS for k in
    {0, 1}, nonlinear NLLS otherwise.

    Args:
        N: number of discretization points.
        m: number of basis-function terms (cast to int).
        basis: basis-function type passed to ``utfc``.
        k: exponent of the nonlinear term, k >= 0. A closed-form solution is
            available for k = 0, 1 and 5.
        xf: right end of the problem domain.

    Returns:
        Tuple ``(x, y, err, res)``:
        - the collocation points;
        - the solution evaluated at them;
        - the absolute error against the closed-form solution (all NaN when
          none is known for ``k``);
        - the absolute residual of the loss vector.
    """
    ## problem initial conditions: *****************************************************************
    xspan = [0., xf]  # problem domain range [x0, xf], with x0 = 0
    y0 = 1.   # y(x0)  = 1
    y0p = 0.  # y'(x0) = 0
    nC = 2    # number of constraints

    ## construct univariate tfc class: *************************************************************
    tfc = utfc(N, nC, int(m), basis=basis, x0=xspan[0], xf=xspan[1])
    x = tfc.x

    H = tfc.H
    dH = tfc.dH
    H0 = H(x[0:1])
    H0p = dH(x[0:1])

    ## define tfc constrained expression and derivatives: ******************************************
    # Switching functions. phi2 = x vanishes at x0 only because x0 = 0.
    phi1 = lambda x: np.ones_like(x)
    phi2 = lambda x: x

    # tfc constrained expression
    y = lambda x, xi: np.dot(H(x), xi) + phi1(x) * (y0 - np.dot(
        H0, xi)) + phi2(x) * (y0p - np.dot(H0p, xi))
    yp = egrad(y)
    ypp = egrad(yp)

    ## define the loss function: *******************************************************************
    L = lambda xi: x * ypp(x, xi) + 2. * yp(x, xi) + x * y(x, xi)**k

    ## solve the problem via (non)linear least-squares *********************************************
    xi = np.zeros(H(x).shape[1])

    # The problem is linear for k == 0 or k == 1. Iteration counts and timings
    # are not returned, so they are discarded.
    if k == 0 or k == 1:
        xi, _ = LS(xi, L, timer=True)
    else:
        xi, _, _ = NLLS(xi, L, timer=True)

    y_sol = y(x, xi)

    ## compute the error (if k = 0, 1, or 5): ******************************************************
    if k == 0:
        ytrue = 1. - 1. / 6. * x**2
    elif k == 1:
        # sin(x)/x, with its limit value 1 at x = 0 (x[0] == x0 == 0).
        ytrue = onp.ones_like(x)
        ytrue[1:] = np.sin(x[1:]) / x[1:]
    elif k == 5:
        ytrue = (1. + x**2 / 3)**(-1 / 2)
    else:
        # No closed form known: use NaN so the error reads as "unknown"
        # instead of being compared against uninitialised memory.
        ytrue = np.full_like(x, np.nan)

    err = np.abs(y_sol - ytrue)

    ## compute the residual of the loss vector: ****************************************************
    res = np.abs(L(xi))

    return x, y_sol, err, res
Example #8
0
 def f(x):
     """Allocate an array shaped like ``x`` and fill it in place with ``x``."""
     # NOTE(review): JAX arrays are immutable and, to my knowledge, have no
     # in-place ``.fill`` method, so this likely raises AttributeError. If that
     # is not the purpose of this helper (e.g. an error-path test),
     # ``jnp.full_like(x, x)`` is the functional equivalent -- confirm intent.
     arr = jnp.empty_like(x)
     arr.fill(x)
     return arr
Example #9
0
def shallow_water_step(state, is_first_step):
    """Perform one step of the shallow-water model.

    Args:
        state: model state unpacked as ``(h, u, v, dh, du, dv)``. Presumably
            these are the height field, the two velocity components and their
            tendencies from the previous step.
        is_first_step: if true, advance with a forward-Euler step
            (``x += dt * dx_new``). Otherwise combine the new and previous
            tendencies with the Adams-Bashforth coefficients.

    Returns:
        ``ModelState(h, u, v, dh_new, du_new, dv_new)`` with the updated
        fields and the tendencies computed in this step.
    """
    token = jax.lax.create_token()

    h, u, v, dh, du, dv = state

    hc = jnp.pad(h[1:-1, 1:-1], 1, "edge")
    hc, token = enforce_boundaries(hc, "h", token)

    fe = jnp.empty_like(u)
    fn = jnp.empty_like(u)

    fe = fe.at[1:-1, 1:-1].set(0.5 * (hc[1:-1, 1:-1] + hc[1:-1, 2:]) * u[1:-1, 1:-1])
    fn = fn.at[1:-1, 1:-1].set(0.5 * (hc[1:-1, 1:-1] + hc[2:, 1:-1]) * v[1:-1, 1:-1])
    fe, token = enforce_boundaries(fe, "u", token)
    fn, token = enforce_boundaries(fn, "v", token)

    dh_new = dh.at[1:-1, 1:-1].set(
        -(fe[1:-1, 1:-1] - fe[1:-1, :-2]) / dx - (fn[1:-1, 1:-1] - fn[:-2, 1:-1]) / dy
    )

    # nonlinear momentum equation
    q = jnp.empty_like(u)
    ke = jnp.empty_like(u)

    # planetary and relative vorticity
    q = q.at[1:-1, 1:-1].set(
        CORIOLIS_PARAM[1:-1, 1:-1]
        + ((v[1:-1, 2:] - v[1:-1, 1:-1]) / dx - (u[2:, 1:-1] - u[1:-1, 1:-1]) / dy)
    )
    # potential vorticity. Written as an explicit product with ``.set``
    # rather than ``.at[].mul``, which is not available in all JAX versions.
    q = q.at[1:-1, 1:-1].set(
        q[1:-1, 1:-1]
        * (1.0 / (0.25 * (hc[1:-1, 1:-1] + hc[1:-1, 2:] + hc[2:, 1:-1] + hc[2:, 2:])))
    )
    q, token = enforce_boundaries(q, "h", token)

    du_new = du.at[1:-1, 1:-1].set(
        -GRAVITY * (h[1:-1, 2:] - h[1:-1, 1:-1]) / dx
        + 0.5
        * (
            q[1:-1, 1:-1] * 0.5 * (fn[1:-1, 1:-1] + fn[1:-1, 2:])
            + q[:-2, 1:-1] * 0.5 * (fn[:-2, 1:-1] + fn[:-2, 2:])
        )
    )
    dv_new = dv.at[1:-1, 1:-1].set(
        -GRAVITY * (h[2:, 1:-1] - h[1:-1, 1:-1]) / dy
        - 0.5
        * (
            q[1:-1, 1:-1] * 0.5 * (fe[1:-1, 1:-1] + fe[2:, 1:-1])
            + q[1:-1, :-2] * 0.5 * (fe[1:-1, :-2] + fe[2:, :-2])
        )
    )
    ke = ke.at[1:-1, 1:-1].set(
        0.5
        * (
            0.5 * (u[1:-1, 1:-1] ** 2 + u[1:-1, :-2] ** 2)
            + 0.5 * (v[1:-1, 1:-1] ** 2 + v[:-2, 1:-1] ** 2)
        )
    )
    ke, token = enforce_boundaries(ke, "h", token)

    du_new = du_new.at[1:-1, 1:-1].add(-(ke[1:-1, 2:] - ke[1:-1, 1:-1]) / dx)
    dv_new = dv_new.at[1:-1, 1:-1].add(-(ke[2:, 1:-1] - ke[1:-1, 1:-1]) / dy)

    if is_first_step:
        # No previous tendencies yet: forward Euler.
        u = u.at[1:-1, 1:-1].add(dt * du_new[1:-1, 1:-1])
        v = v.at[1:-1, 1:-1].add(dt * dv_new[1:-1, 1:-1])
        h = h.at[1:-1, 1:-1].add(dt * dh_new[1:-1, 1:-1])
    else:
        u = u.at[1:-1, 1:-1].add(
            dt
            * (
                ADAMS_BASHFORTH_A * du_new[1:-1, 1:-1]
                + ADAMS_BASHFORTH_B * du[1:-1, 1:-1]
            )
        )
        v = v.at[1:-1, 1:-1].add(
            dt
            * (
                ADAMS_BASHFORTH_A * dv_new[1:-1, 1:-1]
                + ADAMS_BASHFORTH_B * dv[1:-1, 1:-1]
            )
        )
        h = h.at[1:-1, 1:-1].add(
            dt
            * (
                ADAMS_BASHFORTH_A * dh_new[1:-1, 1:-1]
                + ADAMS_BASHFORTH_B * dh[1:-1, 1:-1]
            )
        )

    h, token = enforce_boundaries(h, "h", token)
    u, token = enforce_boundaries(u, "u", token)
    v, token = enforce_boundaries(v, "v", token)

    if LATERAL_VISCOSITY > 0:
        # lateral friction on u: fluxes are gradients of u
        fe = fe.at[1:-1, 1:-1].set(
            LATERAL_VISCOSITY * (u[1:-1, 2:] - u[1:-1, 1:-1]) / dx
        )
        fn = fn.at[1:-1, 1:-1].set(
            LATERAL_VISCOSITY * (u[2:, 1:-1] - u[1:-1, 1:-1]) / dy
        )
        fe, token = enforce_boundaries(fe, "u", token)
        fn, token = enforce_boundaries(fn, "v", token)

        u = u.at[1:-1, 1:-1].add(
            dt
            * (
                (fe[1:-1, 1:-1] - fe[1:-1, :-2]) / dx
                + (fn[1:-1, 1:-1] - fn[:-2, 1:-1]) / dy
            )
        )

        # lateral friction on v: both differences must be taken in v (not
        # v minus u) for these to be gradients of v.
        fe = fe.at[1:-1, 1:-1].set(
            LATERAL_VISCOSITY * (v[1:-1, 2:] - v[1:-1, 1:-1]) / dx
        )
        fn = fn.at[1:-1, 1:-1].set(
            LATERAL_VISCOSITY * (v[2:, 1:-1] - v[1:-1, 1:-1]) / dy
        )
        fe, token = enforce_boundaries(fe, "u", token)
        fn, token = enforce_boundaries(fn, "v", token)

        v = v.at[1:-1, 1:-1].add(
            dt
            * (
                (fe[1:-1, 1:-1] - fe[1:-1, :-2]) / dx
                + (fn[1:-1, 1:-1] - fn[:-2, 1:-1]) / dy
            )
        )

    return ModelState(h, u, v, dh_new, du_new, dv_new)
Example #10
0
def enforce_boundaries(arr, grid, token=None):
    """Exchange halo cells with neighbouring MPI ranks via mpi4jax.

    In each of four rounds, the second-outermost row/column on one side is
    sent to the neighbour in that direction. In the same round, the
    outermost row/column on the opposite side is filled from the neighbour
    there. ``token`` is threaded through every mpi4jax call. Afterwards,
    column -2 (grid "u", non-periodic x, easternmost ranks) or row -2
    (grid "v", northernmost ranks) is set to zero.

    Args:
        arr: array whose first two axes are indexed as (y, x), including a
            one-cell halo on each side.
        grid: grid the array lives on; must be ``"h"``, ``"u"`` or ``"v"``.
        token: mpi4jax ordering token; a fresh one is created if omitted.

    Returns:
        Tuple ``(arr, token)``: the updated array and the latest token.
    """
    assert grid in ("h", "u", "v")

    row, col = proc_idx[0], proc_idx[1]

    # (direction to send to, direction to receive from) for each round.
    # Every rank walks the same fixed sequence, so a rank's send in a round
    # pairs with its neighbour's receive in that same round.
    exchange_rounds = (
        ("west", "east"),
        ("north", "south"),
        ("east", "west"),
        ("south", "north"),
    )

    # Interior edge that gets sent ...
    send_slices = {
        "south": (1, slice(None), Ellipsis),
        "west": (slice(None), 1, Ellipsis),
        "north": (-2, slice(None), Ellipsis),
        "east": (slice(None), -2, Ellipsis),
    }
    # ... and the halo it overwrites on the receiving side.
    recv_slices = {
        "south": (0, slice(None), Ellipsis),
        "west": (slice(None), 0, Ellipsis),
        "north": (-1, slice(None), Ellipsis),
        "east": (slice(None), -1, Ellipsis),
    }

    # (row, col) of each neighbouring rank, or None at the domain edge.
    neighbors = {
        "south": (row - 1, col) if row > 0 else None,
        "west": (row, col - 1) if col > 0 else None,
        "north": (row + 1, col) if row < nproc_y - 1 else None,
        "east": (row, col + 1) if col < nproc_x - 1 else None,
    }
    if PERIODIC_BOUNDARY_X:
        # Wrap around in x: first and last rank columns are neighbours.
        if col == 0:
            neighbors["west"] = (row, nproc_x - 1)
        if col == nproc_x - 1:
            neighbors["east"] = (row, 0)

    if token is None:
        token = jax.lax.create_token()

    rank_grid = (nproc_y, nproc_x)
    for send_dir, recv_dir in exchange_rounds:
        dest = neighbors[send_dir]
        source = neighbors[recv_dir]
        if dest is None and source is None:
            continue

        # Convert (row, col) coordinates into flat MPI ranks.
        if dest is not None:
            dest = np.ravel_multi_index(dest, rank_grid)
        if source is not None:
            source = np.ravel_multi_index(source, rank_grid)

        halo = recv_slices[recv_dir]
        recv_buf = jnp.empty_like(arr[halo])
        send_buf = arr[send_slices[send_dir]]

        if dest is None:
            recv_buf, token = mpi4jax.recv(
                recv_buf, source=source, comm=mpi_comm, token=token
            )
        elif source is None:
            token = mpi4jax.send(send_buf, dest=dest, comm=mpi_comm, token=token)
            continue  # send-only round: nothing to write back
        else:
            recv_buf, token = mpi4jax.sendrecv(
                send_buf,
                recv_buf,
                source=source,
                dest=dest,
                comm=mpi_comm,
                token=token,
            )
        arr = arr.at[halo].set(recv_buf)

    # Zero column -2 on the easternmost ranks for the u grid (non-periodic x).
    if grid == "u" and not PERIODIC_BOUNDARY_X and col == nproc_x - 1:
        arr = arr.at[:, -2].set(0.0)

    # Zero row -2 on the northernmost ranks for the v grid.
    if grid == "v" and row == nproc_y - 1:
        arr = arr.at[-2, :].set(0.0)

    return arr, token
Example #11
0
def _bdf_step(state, fun, jac):
    """Retry a BDF step until it is accepted, then advance ``state``.

    Inside a ``jax.lax.while_loop`` each attempt:
    1. Runs a Newton iteration on the BDF equation.
    2. On non-convergence, either refreshes the Jacobian (if not yet done
       for this step) or shrinks the step by 0.3.
    3. If converged, estimates the local error. An error norm > 1 rejects
       the step and rescales it; otherwise the step is accepted.

    After acceptance, time and step counters are advanced. The next step is
    then prepared either at the current order or with an order change.

    Args:
        state: solver state (a NamedTuple-like object with ``_replace``).
        fun: right-hand-side function passed to ``_newton_iteration``.
        jac: Jacobian function passed to ``_update_jacobian``.

    Returns:
        The updated solver state.
    """
    # The existing Jacobian is reused unless the Newton iteration fails to
    # converge.
    updated_jacobian = False
    # Initialise step size and try to make the step; iterate, reducing the
    # step size until the error is in bounds.
    step_accepted = False
    y = jnp.empty_like(state.y0)
    d = jnp.empty_like(state.y0)
    n_iter = -1

    # loop until step is accepted
    while_state = [state, step_accepted, updated_jacobian, y, d, n_iter]

    def while_cond(while_state):
        _, step_accepted, _, _, _, _ = while_state
        return step_accepted == False  # noqa: E712

    def while_body(while_state):
        state, step_accepted, updated_jacobian, y, d, n_iter = while_state

        # solve BDF equation using y0 as starting point
        converged, n_iter, y, d, state = _newton_iteration(state, fun)
        not_converged = converged == False  # noqa: E712

        # Newton iteration did not converge, but the Jacobian has already been
        # evaluated: reduce step size by 0.3 (as per [1]) and try again.
        # (jax.tree_util.tree_multimap was removed from JAX; tree_map accepts
        # multiple trees and is the drop-in replacement.)
        state = jax.tree_util.tree_map(
            partial(jnp.where, not_converged * updated_jacobian),
            _update_step_size_and_lu(state, 0.3),
            state,
        )

        # Not converged and Jacobian not yet updated: update the Jacobian and
        # try again. ``False + updated_jacobian`` leaves the flag unchanged in
        # the other branch.
        (state, updated_jacobian) = jax.tree_util.tree_map(
            partial(
                jnp.where, not_converged * (updated_jacobian == False)  # noqa: E712
            ),
            (_update_jacobian(state, jac), True),
            (state, False + updated_jacobian),
        )

        safety = 0.9 * (2 * NEWTON_MAXITER + 1) / (2 * NEWTON_MAXITER + n_iter)
        scale_y = state.atol + state.rtol * jnp.abs(y)

        # combine eq 3, 4 and 6 from [1] to obtain error
        # Note that error = C_k * h^{k+1} y^{k+1}
        # and d = D^{k+1} y_{n+1} \approx h^{k+1} y^{k+1}
        error = state.error_const[state.order] * d

        error_norm = rms_norm(error / scale_y)

        # calculate optimal step size factor as per eq 2.46 of [2]
        factor = jnp.maximum(
            MIN_FACTOR, safety * error_norm ** (-1 / (state.order + 1))
        )

        # Converged but error too large: rescale the step and reject it.
        # Otherwise the step is accepted iff the Newton iteration converged.
        (state, step_accepted) = jax.tree_util.tree_map(
            partial(jnp.where, converged * (error_norm > 1)),  # noqa: E712
            (_update_step_size_and_lu(state, factor), False),
            (state, converged),
        )

        return [state, step_accepted, updated_jacobian, y, d, n_iter]

    state, step_accepted, updated_jacobian, y, d, n_iter = jax.lax.while_loop(
        while_cond, while_body, while_state
    )

    # take the accepted step
    n_steps = state.n_steps + 1
    t = state.t + state.h

    # a change in order is only done after running at order k for k + 1 steps
    # (see page 83 of [2])
    n_equal_steps = state.n_equal_steps + 1

    state = state._replace(n_equal_steps=n_equal_steps, t=t, n_steps=n_steps)

    state = jax.tree_util.tree_map(
        partial(jnp.where, n_equal_steps < state.order + 1),
        _prepare_next_step(state, d),
        _prepare_next_step_order_change(state, d, y, n_iter),
    )

    return state