def __init__(self,
             A: jnp.ndarray,
             B: jnp.ndarray,
             Q: jnp.ndarray = None,
             R: jnp.ndarray = None,
             K: jnp.ndarray = None,
             start_time: int = 0,
             H: int = 5,
             lr_scale: Real = 0.005,
             decay: bool = False,
             delta: Real = 0.01) -> None:
    """
    Description: Initialize the dynamics of the model.

    Args:
        A (jnp.ndarray): system dynamics matrix
        B (jnp.ndarray): control dynamics matrix
        Q (jnp.ndarray): state cost matrix (i.e. cost = x^T Q x + u^T R u)
        R (jnp.ndarray): action cost matrix (i.e. cost = x^T Q x + u^T R u)
        K (jnp.ndarray): starting policy (optional). Defaults to LQR gain.
        start_time (int): initial time step
        H (positive int): history length of the controller
        lr_scale (Real): learning rate scale
        decay (bool): whether to decay the learning rate over time
        delta (Real): scale of the random policy perturbations
    """
    self.d_state, self.d_action = B.shape  # State & Action Dimensions

    self.A, self.B = A, B  # System Dynamics

    self.t = 0  # Time Counter (for decaying learning rate)

    self.H = H
    self.lr_scale, self.decay = lr_scale, decay
    self.delta = delta

    # Model Parameters
    # initial linear policy / perturbation contributions / bias
    # TODO: need to address problem of LQR with jax.lax.scan
    self.K = K if K is not None else LQR(self.A, self.B, Q, R).K

    self.M = self.delta * generate_uniform(
        (H, self.d_action, self.d_state))

    # Past H noises ordered increasing in time
    self.noise_history = jnp.zeros((H, self.d_state, 1))

    # past state and past action
    self.state, self.action = jnp.zeros((self.d_state, 1)), jnp.zeros(
        (self.d_action, 1))

    self.eps = generate_uniform((H, H, self.d_action, self.d_state))
    self.eps_bias = generate_uniform((H, self.d_action, 1))

    def grad(M, noise_history, cost):
        """Bandit (one-point) gradient estimate for M from the observed scalar cost."""
        return cost * jnp.sum(self.eps, axis=0)

    self.grad = grad
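# Illustration (not part of the controller above): the `grad` closure is a one-point
# bandit gradient estimate -- the observed scalar cost times the sum of the stored
# random perturbation directions, which yields an array with the same shape as M.
# A minimal standalone sketch of that estimator follows; the shapes, the PRNG key,
# and the cost value are assumptions for the example, and any 1/delta normalization
# is assumed to be absorbed into the learning rate.
import jax.numpy as jnp
from jax import random

H, d_action, d_state = 5, 1, 2
key = random.PRNGKey(0)

# One perturbation direction per step of history, each with the same shape as M.
eps = random.uniform(key, (H, H, d_action, d_state), minval=-1.0, maxval=1.0)
observed_cost = 1.3  # scalar cost observed after playing the perturbed policy

# One-point estimate of the gradient of the cost with respect to M.
grad_estimate = observed_cost * jnp.sum(eps, axis=0)  # shape (H, d_action, d_state)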
def __init__(
    self,
    A: jnp.ndarray,
    B: jnp.ndarray,
    Q: jnp.ndarray = None,
    R: jnp.ndarray = None,
    K: jnp.ndarray = None,
    start_time: int = 0,
    cost_fn: Callable[[jnp.ndarray, jnp.ndarray], Real] = None,
    H: int = 3,
    HH: int = 2,
    lr_scale: Real = 0.005,
    decay: bool = True,
) -> None:
    """
    Description: Initialize the dynamics of the model.

    Args:
        A (jnp.ndarray): system dynamics matrix
        B (jnp.ndarray): control dynamics matrix
        Q (jnp.ndarray): state cost matrix (i.e. cost = x^T Q x + u^T R u)
        R (jnp.ndarray): action cost matrix (i.e. cost = x^T Q x + u^T R u)
        K (jnp.ndarray): starting policy (optional). Defaults to LQR gain.
        start_time (int): initial time step
        cost_fn (Callable[[jnp.ndarray, jnp.ndarray], Real]): loss function mapping
            (state, action) to a scalar cost. Defaults to quadratic loss.
        H (positive int): history of the controller
        HH (positive int): history of the system
        lr_scale (Real): learning rate scale
        decay (bool): whether to decay the learning rate over time
    """
    cost_fn = cost_fn or quad_loss

    d_state, d_action = B.shape  # State & Action Dimensions

    self.A, self.B = A, B  # System Dynamics

    self.t = 0  # Time Counter (for decaying learning rate)

    self.H, self.HH = H, HH
    self.lr_scale, self.decay = lr_scale, decay

    self.bias = 0

    # Model Parameters
    # initial linear policy / perturbation contributions / bias
    # TODO: need to address problem of LQR with jax.lax.scan
    self.K = K if K is not None else LQR(self.A, self.B, Q, R).K

    self.M = jnp.zeros((H, d_action, d_state))

    # Past H + HH noises ordered increasing in time
    self.noise_history = jnp.zeros((H + HH, d_state, 1))

    # past state and past action
    self.state, self.action = jnp.zeros((d_state, 1)), jnp.zeros(
        (d_action, 1))

    def last_h_noises():
        """Return the most recent H noise terms from the noise history."""
        return jax.lax.dynamic_slice_in_dim(self.noise_history, -H, H)

    self.last_h_noises = last_h_noises

    def policy_loss(M, w):
        """Surrogate cost: roll the dynamics forward under policy (K, M) and noises w."""

        def action(state, h):
            """Action function"""
            return -self.K @ state + jnp.tensordot(
                M, jax.lax.dynamic_slice_in_dim(w, h, H),
                axes=([0, 2], [0, 1]))

        def evolve(state, h):
            """Evolve function"""
            return self.A @ state + self.B @ action(state, h) + w[h + H], None

        final_state, _ = jax.lax.scan(evolve, jnp.zeros((d_state, 1)),
                                      jnp.arange(H - 1))
        return cost_fn(final_state, action(final_state, HH - 1))

    self.policy_loss = policy_loss
    self.grad = jit(grad(policy_loss, (0, 1)))
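# Illustration (not part of the controller above): `policy_loss` rolls the dynamics
# forward with jax.lax.scan and is differentiated end-to-end, so the controller can
# take gradient steps on M through the whole counterfactual rollout. A standalone
# toy version of the same pattern follows; the matrices, dimensions, and the
# quadratic surrogate cost are assumptions for the example only.
import jax
import jax.numpy as jnp

d_state, d_action, H = 2, 1, 3
A = jnp.array([[1.0, 0.1], [0.0, 1.0]])
B = jnp.array([[0.0], [0.1]])
K = jnp.zeros((d_action, d_state))
w = jnp.zeros((2 * H, d_state, 1))  # disturbance history, oldest first

def toy_policy_loss(M, w):
    """Counterfactual cost of playing the policy (K, M) against the noises w."""

    def action(state, h):
        # Linear feedback plus a learned combination of the last H disturbances.
        return -K @ state + jnp.tensordot(
            M, jax.lax.dynamic_slice_in_dim(w, h, H), axes=([0, 2], [0, 1]))

    def evolve(state, h):
        # One step of the linear dynamics driven by the recorded disturbances.
        return A @ state + B @ action(state, h) + w[h + H], None

    final_state, _ = jax.lax.scan(evolve, jnp.zeros((d_state, 1)),
                                  jnp.arange(H - 1))
    u = action(final_state, H - 1)
    return jnp.sum(final_state ** 2) + jnp.sum(u ** 2)  # quadratic surrogate cost

M0 = jnp.zeros((H, d_action, d_state))
dM = jax.grad(toy_policy_loss)(M0, w)  # gradient with the same shape as M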