Example #1
    def __init__(self,
                 A: jnp.ndarray,
                 B: jnp.ndarray,
                 Q: jnp.ndarray = None,
                 R: jnp.ndarray = None,
                 K: jnp.ndarray = None,
                 start_time: int = 0,
                 H: int = 5,
                 lr_scale: Real = 0.005,
                 decay: bool = False,
                 delta: Real = 0.01) -> None:
        """
        Description: Initialize the dynamics of the model.

        Args:
            A (jnp.ndarray): system dynamics (state transition matrix)
            B (jnp.ndarray): system dynamics (control input matrix)
            Q (jnp.ndarray): state cost matrix (i.e. cost = x^T Q x + u^T R u)
            R (jnp.ndarray): action cost matrix (i.e. cost = x^T Q x + u^T R u)
            K (jnp.ndarray): starting policy (optional). Defaults to LQR gain.
            start_time (int): starting time step of the controller
            H (positive int): history length of the controller
            lr_scale (Real): scale of the learning rate
            decay (bool): whether to decay the learning rate over time
            delta (Real): scale of the initial perturbation parameters M
        """

        self.d_state, self.d_action = B.shape  # State & Action Dimensions

        self.A, self.B = A, B  # System Dynamics

        self.t = 0  # Time Counter (for decaying learning rate)

        self.H = H

        self.lr_scale, self.decay = lr_scale, decay

        self.delta = delta

        # Model Parameters
        # initial linear policy / perturbation contributions / bias
        # TODO: need to address problem of LQR with jax.lax.scan
        self.K = K if K is not None else LQR(self.A, self.B, Q, R).K

        self.M = self.delta * generate_uniform(
            (H, self.d_action, self.d_state))

        # Past H noises ordered increasing in time
        self.noise_history = jnp.zeros((H, self.d_state, 1))

        # past state and past action
        self.state = jnp.zeros((self.d_state, 1))
        self.action = jnp.zeros((self.d_action, 1))

        self.eps = generate_uniform((H, H, self.d_action, self.d_state))
        self.eps_bias = generate_uniform((H, self.d_action, 1))

        def grad(M, noise_history, cost):
            """Gradient estimate: scale the summed exploration perturbations
            by the observed scalar cost (M and noise_history are unused here)"""
            return cost * jnp.sum(self.eps, axis=0)

        self.grad = grad
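
For context, a minimal usage sketch. The class name and the helper
generate_uniform are not shown in the snippet, so Controller below is a
hypothetical name and the system is an illustrative double integrator:

import jax.numpy as jnp

# Hypothetical 2-state, 1-action system (double integrator),
# so B.shape unpacks to (d_state, d_action) = (2, 1).
A = jnp.array([[1.0, 1.0],
               [0.0, 1.0]])
B = jnp.array([[0.0],
               [1.0]])
Q = jnp.eye(2)  # state cost
R = jnp.eye(1)  # action cost

# `Controller` stands in for whatever class this __init__ belongs to.
ctrl = Controller(A, B, Q=Q, R=R, H=5, lr_scale=0.005, delta=0.01)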
Example #2
    def __init__(
        self,
        A: jnp.ndarray,
        B: jnp.ndarray,
        Q: jnp.ndarray = None,
        R: jnp.ndarray = None,
        K: jnp.ndarray = None,
        start_time: int = 0,
        cost_fn: Callable[[jnp.ndarray, jnp.ndarray], Real] = None,
        H: int = 3,
        HH: int = 2,
        lr_scale: Real = 0.005,
        decay: bool = True,
    ) -> None:
        """
        Description: Initialize the dynamics of the model.

        Args:
            A (jnp.ndarray): system dynamics (state transition matrix)
            B (jnp.ndarray): system dynamics (control input matrix)
            Q (jnp.ndarray): state cost matrix (i.e. cost = x^T Q x + u^T R u)
            R (jnp.ndarray): action cost matrix (i.e. cost = x^T Q x + u^T R u)
            K (jnp.ndarray): starting policy (optional). Defaults to LQR gain.
            start_time (int): starting time step of the controller
            cost_fn (Callable[[jnp.ndarray, jnp.ndarray], Real]): cost function
                mapping (state, action) to a scalar cost; defaults to quad_loss
            H (positive int): history length of the controller
            HH (positive int): history length of the system
            lr_scale (Real): scale of the learning rate
            decay (bool): whether to decay the learning rate over time
        """

        cost_fn = cost_fn or quad_loss

        d_state, d_action = B.shape  # State & Action Dimensions

        self.A, self.B = A, B  # System Dynamics

        self.t = 0  # Time Counter (for decaying learning rate)

        self.H, self.HH = H, HH

        self.lr_scale, self.decay = lr_scale, decay

        self.bias = 0

        # Model Parameters
        # initial linear policy / perturbation contributions / bias
        # TODO: need to address problem of LQR with jax.lax.scan
        self.K = K if K is not None else LQR(self.A, self.B, Q, R).K

        self.M = jnp.zeros((H, d_action, d_state))

        # Past H + HH noises ordered increasing in time
        self.noise_history = jnp.zeros((H + HH, d_state, 1))

        # past state and past action
        self.state = jnp.zeros((d_state, 1))
        self.action = jnp.zeros((d_action, 1))

        def last_h_noises():
            """Return the last H entries of the noise history"""
            return jax.lax.dynamic_slice_in_dim(self.noise_history, -H, H)

        self.last_h_noises = last_h_noises

        def policy_loss(M, w):
            """Surrogate cost of the policy (K, M) under noise sequence w"""
            def action(state, h):
                """Linear feedback -K @ state plus the learned contributions
                of the H noises starting at index h"""
                return -self.K @ state + jnp.tensordot(
                    M,
                    jax.lax.dynamic_slice_in_dim(w, h, H),
                    axes=([0, 2], [0, 1]))

            def evolve(state, h):
                """Evolve the surrogate dynamics one step"""
                return self.A @ state + self.B @ action(state, h) + w[h + H], None

            final_state, _ = jax.lax.scan(evolve, jnp.zeros((d_state, 1)),
                                          jnp.arange(H - 1))
            return cost_fn(final_state, action(final_state, HH - 1))

        self.policy_loss = policy_loss
        self.grad = jit(grad(policy_loss, (0, 1)))
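
To make the defaults concrete, here is a plausible quad_loss consistent
with the docstring's cost = x^T Q x + u^T R u (taking Q = R = I, an
assumption, since the real quad_loss is not shown), followed by a hedged
sketch of one gradient step using self.grad:

import jax.numpy as jnp

def quad_loss(x: jnp.ndarray, u: jnp.ndarray):
    """Quadratic cost x^T x + u^T u for column vectors x and u."""
    return (x.T @ x + u.T @ u).squeeze()

# Sketch of one update step (the actual update method is not shown).
# self.grad returns gradients of policy_loss w.r.t. args 0 and 1,
# i.e. (dM, dw); only dM is used to update the policy parameters here:
# delta_M, _ = self.grad(self.M, self.noise_history)
# lr = self.lr_scale / (self.t + 1) if self.decay else self.lr_scale
# self.M = self.M - lr * delta_M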