class LDS(Env):
    """Linear dynamical system: x_{t+1} = A x_t + B u_t with observation y_t = C x_t."""

    def __init__(self, state_size=1, action_size=1, A=None, B=None, C=None, seed=0):
        if A is not None:
            assert (
                A.shape[0] == state_size and A.shape[1] == state_size
            ), "ERROR: Your input dynamics matrix does not have the correct shape."
        if B is not None:
            assert (
                B.shape[0] == state_size and B.shape[1] == action_size
            ), "ERROR: Your input dynamics matrix does not have the correct shape."
        self.random = Random(seed)
        self.state_size, self.action_size = state_size, action_size
        self.A = (A if A is not None else jax.random.normal(
            self.random.generate_key(), shape=(state_size, state_size)))
        self.B = (B if B is not None else jax.random.normal(
            self.random.generate_key(), shape=(state_size, action_size)))
        self.C = (C if C is not None else jax.numpy.identity(self.state_size))
        self.t = 0
        self.reset()

    def step(self, action):
        self.state = self.A @ self.state + self.B @ action
        self.obs = self.C @ self.state

    @jax.jit
    def dynamics(self, state, action):
        new_state = self.A @ state + self.B @ action
        return new_state

    def reset(self):
        self.state = jax.random.normal(self.random.generate_key(),
                                       shape=(self.state_size, 1))
        self.obs = self.C @ self.state
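# Usage sketch (not part of the original module): roll the LDS out for a few steps
# with zero inputs and collect the observations it exposes via `env.obs`. The helper
# name `rollout_lds` and the horizon `T` are illustrative, not library API.
def rollout_lds(T=10, state_size=2, action_size=1, seed=0):
    env = LDS(state_size=state_size, action_size=action_size, seed=seed)
    observations = []
    for _ in range(T):
        env.step(jnp.zeros((action_size, 1)))  # step() mutates env.state and env.obs in place
        observations.append(env.obs)
    return observations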
class Deep(Agent):
    """
    Generic deep controller that uses zero-order methods to train on an environment.
    """

    def __init__(
        self,
        env: Env,
        learning_rate: Real = 0.001,
        gamma: Real = 0.99,
        max_episode_length: int = 500,
        seed: int = 0,
    ) -> None:
        """
        Description: initializes the Deep agent

        Args:
            env (Env): a deluca environment
            learning_rate (Real): step size for the policy-gradient updates
            gamma (Real): reward discount factor
            max_episode_length (int): maximum number of steps stored per episode
            seed (int): PRNG seed

        Returns:
            None
        """
        # Create gym and seed numpy
        self.env = env
        self.max_episode_length = max_episode_length
        self.lr = learning_rate
        self.gamma = gamma
        self.random = Random(seed)

        self.reset()

    def reset(self) -> None:
        """
        Description: reset agent

        Args:
            None

        Returns:
            None
        """
        # Init weight
        self.W = jax.random.uniform(
            self.random.generate_key(),
            shape=(self.env.state_size, len(self.env.action_space)),
            minval=0,
            maxval=1,
        )

        # Keep stats for final print of graph
        self.episode_rewards = []

        self.current_episode_length = 0
        self.current_episode_reward = 0
        self.episode_rewards = jnp.zeros(self.max_episode_length)
        self.episode_grads = jnp.zeros(
            (self.max_episode_length, self.W.shape[0], self.W.shape[1]))

    def policy(self, state: jnp.ndarray, w: jnp.ndarray) -> jnp.ndarray:
        """
        Description: Policy that maps state to action parameterized by w

        Args:
            state (jnp.ndarray):
            w (jnp.ndarray):
        """
        z = jnp.dot(state, w)
        exp = jnp.exp(z)
        return exp / jnp.sum(exp)

    def softmax_grad(self, softmax: jnp.ndarray) -> jnp.ndarray:
        """
        Description: Vectorized softmax Jacobian

        Args:
            softmax (jnp.ndarray)
        """
        s = softmax.reshape(-1, 1)
        return jnp.diagflat(s) - jnp.dot(s, s.T)

    def __call__(self, state: jnp.ndarray):
        """
        Description: provide an action given a state

        Args:
            state (jnp.ndarray):

        Returns:
            jnp.ndarray: action to take
        """
        self.state = state
        self.probs = self.policy(state, self.W)
        self.action = jax.random.choice(
            self.random.generate_key(), a=self.env.action_space, p=self.probs
        )
        return self.action

    def feed(self, reward: Real) -> None:
        """
        Description: compute gradient and save with reward in memory for weight updates

        Args:
            reward (Real):

        Returns:
            None
        """
        dsoftmax = self.softmax_grad(self.probs)[self.action, :]
        dlog = dsoftmax / self.probs[self.action]
        grad = self.state.reshape(-1, 1) @ dlog.reshape(1, -1)

        self.episode_rewards = jax.ops.index_update(
            self.episode_rewards, self.current_episode_length, reward
        )
        self.episode_grads = jax.ops.index_update(
            self.episode_grads, self.current_episode_length, grad
        )
        self.current_episode_length += 1

    def update(self) -> None:
        """
        Description: update weights

        Args:
            None

        Returns:
            None
        """
        for i in range(self.current_episode_length):
            # Loop through everything that happened in the episode and update
            # towards the log policy gradient times **FUTURE** reward
            self.W += self.lr * self.episode_grads[i] + jnp.sum(
                jnp.array(
                    [
                        r * (self.gamma ** r)
                        for r in self.episode_rewards[i : self.current_episode_length]
                    ]
                )
            )

        # reset episode length
        self.current_episode_length = 0
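# Usage sketch (assumption, not from the source): a REINFORCE-style training loop for
# the Deep agent. It assumes a deluca-style environment that exposes `state_size`, a
# sequence of discrete actions as `action_space`, and a gym-like step() returning
# (state, reward, done, info); `train_deep` and `num_episodes` are illustrative names.
def train_deep(env, num_episodes=100, **agent_kwargs):
    agent = Deep(env, **agent_kwargs)
    for _ in range(num_episodes):
        state = env.reset()
        for _ in range(agent.max_episode_length):
            action = agent(state)                      # sample an action from the softmax policy
            state, reward, done, _ = env.step(action)
            agent.feed(reward)                         # store reward and log-policy gradient
            if done:
                break
        agent.update()                                 # apply the REINFORCE-style weight update
    return agent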
class MountainCar(Env):
    """Continuous-action mountain car (gym MountainCarContinuous-style) with JAX dynamics."""

    def __init__(self, goal_velocity=0, seed=0):
        self.min_action = -1.0
        self.max_action = 1.0
        self.min_position = -1.2
        self.max_position = 0.6
        self.max_speed = 0.07
        self.goal_position = 0.45  # was 0.5 in gym, 0.45 in Arnaud de Broissia's version
        self.goal_velocity = goal_velocity
        self.power = 0.0015
        self.random = Random(seed)
        self.viewer = None  # lazily created in render()

        self.low_state = np.array([self.min_position, -self.max_speed], dtype=np.float32)
        self.high_state = np.array([self.max_position, self.max_speed], dtype=np.float32)

        self.action_space = spaces.Box(low=self.min_action,
                                       high=self.max_action,
                                       shape=(1, ),
                                       dtype=np.float32)
        self.observation_space = spaces.Box(low=self.low_state,
                                            high=self.high_state,
                                            dtype=np.float32)

        self.reset()

    @jax.jit
    def dynamics(self, state, action):
        position = state[0]
        velocity = state[1]
        force = jnp.minimum(jnp.maximum(action, self.min_action), self.max_action)

        velocity += force * self.power - 0.0025 * jnp.cos(3 * position)
        velocity = jnp.clip(velocity, -self.max_speed, self.max_speed)
        position += velocity
        position = jnp.clip(position, self.min_position, self.max_position)

        reset_velocity = (position == self.min_position) & (velocity < 0)
        velocity = jax.lax.cond(reset_velocity == 1, lambda x: 0.0, lambda x: x, velocity)
        # if (position == self.min_position and velocity < 0): velocity = 0

        return jnp.array([position, velocity])

    def step(self, action):
        self.state = self.dynamics(self.state, action)
        position = self.state[0]
        velocity = self.state[1]

        # Convert a possible numpy bool to a Python bool.
        done = (position >= self.goal_position) & (velocity >= self.goal_velocity)

        reward = 100.0 * done
        reward -= jnp.power(action, 2) * 0.1

        return self.state, reward, done, {}

    def reset(self):
        self.state = jnp.array([
            jax.random.uniform(self.random.generate_key(), minval=-0.6, maxval=0.4), 0
        ])
        return self.state

    def _height(self, xs):
        return jnp.sin(3 * xs) * 0.45 + 0.55

    def render(self, mode="human"):
        screen_width = 600
        screen_height = 400

        world_width = self.max_position - self.min_position
        scale = screen_width / world_width
        carwidth = 40
        carheight = 20

        if self.viewer is None:
            from gym.envs.classic_control import rendering

            self.viewer = rendering.Viewer(screen_width, screen_height)
            xs = np.linspace(self.min_position, self.max_position, 100)
            ys = self._height(xs)
            xys = list(zip((xs - self.min_position) * scale, ys * scale))

            self.track = rendering.make_polyline(xys)
            self.track.set_linewidth(4)
            self.viewer.add_geom(self.track)

            clearance = 10

            l, r, t, b = -carwidth / 2, carwidth / 2, carheight, 0
            car = rendering.FilledPolygon([(l, b), (l, t), (r, t), (r, b)])
            car.add_attr(rendering.Transform(translation=(0, clearance)))
            self.cartrans = rendering.Transform()
            car.add_attr(self.cartrans)
            self.viewer.add_geom(car)
            frontwheel = rendering.make_circle(carheight / 2.5)
            frontwheel.set_color(0.5, 0.5, 0.5)
            frontwheel.add_attr(
                rendering.Transform(translation=(carwidth / 4, clearance)))
            frontwheel.add_attr(self.cartrans)
            self.viewer.add_geom(frontwheel)
            backwheel = rendering.make_circle(carheight / 2.5)
            backwheel.add_attr(
                rendering.Transform(translation=(-carwidth / 4, clearance)))
            backwheel.add_attr(self.cartrans)
            backwheel.set_color(0.5, 0.5, 0.5)
            self.viewer.add_geom(backwheel)
            flagx = (self.goal_position - self.min_position) * scale
            flagy1 = self._height(self.goal_position) * scale
            flagy2 = flagy1 + 50
            flagpole = rendering.Line((flagx, flagy1), (flagx, flagy2))
            self.viewer.add_geom(flagpole)
            flag = rendering.FilledPolygon([(flagx, flagy2), (flagx, flagy2 - 10),
                                            (flagx + 25, flagy2 - 5)])
            flag.set_color(0.8, 0.8, 0)
            self.viewer.add_geom(flag)

        pos = self.state[0]
        self.cartrans.set_translation((pos - self.min_position) * scale,
                                      self._height(pos) * scale)
        self.cartrans.set_rotation(math.cos(3 * pos))

        return self.viewer.render(return_rgb_array=mode == "rgb_array")
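# Usage sketch (assumption): drive this MountainCar with random forces until the goal
# is reached or `T` steps elapse. A scalar action is passed because the jax.lax.cond
# inside dynamics() expects a scalar predicate; `rollout_mountain_car` is an
# illustrative name, not part of the module.
def rollout_mountain_car(T=200, seed=0):
    env = MountainCar(seed=seed)
    state, reward = env.reset(), 0.0
    for _ in range(T):
        action = float(env.action_space.sample()[0])   # random force in [-1, 1]
        state, reward, done, _ = env.step(action)
        if done:
            break
    return state, reward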
class Acrobot(Env):
    """
    Acrobot is a 2-link pendulum with only the second joint actuated.
    Initially, both links point downwards. The goal is to swing the
    end-effector at a height at least the length of one link above the base.
    Both links can swing freely and can pass by each other, i.e., they don't
    collide when they have the same angle.

    **STATE:**
    The state consists of the sin() and cos() of the two rotational joint
    angles and the joint angular velocities:
    [cos(theta1) sin(theta1) cos(theta2) sin(theta2) thetaDot1 thetaDot2].
    For the first link, an angle of 0 corresponds to the link pointing downwards.
    The angle of the second link is relative to the angle of the first link.
    An angle of 0 corresponds to having the same angle between the two links.
    A state of [1, 0, 1, 0, ..., ...] means that both links point downwards.

    **ACTIONS:**
    The action is either applying +1, 0 or -1 torque on the joint between
    the two pendulum links.

    **REFERENCE:**
    .. warning::
        This version of the domain uses the Runge-Kutta method for integrating
        the system dynamics and is more realistic, but also considerably harder
        than the original version which employs Euler integration, see the
        AcrobotLegacy class.
    """

    dt = 0.2

    LINK_LENGTH_1 = 1.0  # [m]
    LINK_LENGTH_2 = 1.0  # [m]
    LINK_MASS_1 = 1.0  #: [kg] mass of link 1
    LINK_MASS_2 = 1.0  #: [kg] mass of link 2
    LINK_COM_POS_1 = 0.5  #: [m] position of the center of mass of link 1
    LINK_COM_POS_2 = 0.5  #: [m] position of the center of mass of link 2
    LINK_MOI = 1.0  #: moments of inertia for both links

    MAX_VEL_1 = 4 * pi
    MAX_VEL_2 = 9 * pi

    AVAIL_TORQUE = jnp.array([-1.0, 0.0, +1])

    torque_noise_max = 0.0

    #: use dynamics equations from the nips paper or the book
    book_or_nips = "book"
    action_arrow = None
    domain_fig = None
    actions_num = 3

    def __init__(self, seed=0, horizon=50):
        high = np.array([1.0, 1.0, 1.0, 1.0, self.MAX_VEL_1, self.MAX_VEL_2],
                        dtype=np.float32)
        low = -high
        self.random = Random(seed)
        self.observation_space = spaces.Box(low=low, high=high, dtype=np.float32)
        self.action_space = spaces.Discrete(3)
        self.action_dim = 1
        self.H = horizon
        self.nsamples = 0

        # @jax.jit
        def _dynamics(state, action):
            self.nsamples += 1
            # Augment the state with our force action so it can be passed to _dsdt
            augmented_state = jnp.append(state, action)
            new_state = rk4(self._dsdt, augmented_state, [0, self.dt])
            # only care about final timestep of integration returned by integrator
            new_state = new_state[-1]
            new_state = new_state[:4]  # omit action
            # ODEINT IS TOO SLOW!
            # ns_continuous = integrate.odeint(self._dsdt, self.s_continuous, [0, self.dt])
            # self.s_continuous = ns_continuous[-1]
            # We only care about the state at the ''final timestep'', self.dt
            new_state = jax.ops.index_update(new_state, 0, wrap(new_state[0], -pi, pi))
            new_state = jax.ops.index_update(new_state, 1, wrap(new_state[1], -pi, pi))
            new_state = jax.ops.index_update(
                new_state, 2, bound(new_state[2], -self.MAX_VEL_1, self.MAX_VEL_1))
            new_state = jax.ops.index_update(
                new_state, 3, bound(new_state[3], -self.MAX_VEL_2, self.MAX_VEL_2))
            return new_state

        # @jax.jit
        def c(x, u):
            return u[0]**2 + 1 - jnp.exp(0.5 * cos(x[0]) + 0.5 * cos(x[1]) - 1)

        self.dynamics = _dynamics
        self.reward_fn = c

        self.f, self.f_x, self.f_u = (
            _dynamics,
            jax.jacfwd(_dynamics, argnums=0),
            jax.jacfwd(_dynamics, argnums=1),
        )
        self.c, self.c_x, self.c_u, self.c_xx, self.c_uu = (
            c,
            jax.grad(c, argnums=0),
            jax.grad(c, argnums=1),
            jax.hessian(c, argnums=0),
            jax.hessian(c, argnums=1),
        )

        self.reset()

    def reset(self):
        self.state = jax.random.uniform(
            self.random.generate_key(), shape=(4, ), minval=-0.1, maxval=0.1
            # self.random.generate_key(), shape=(6,), minval=-0.1, maxval=0.1
        )
        return self.state

    def step(self, action):
        # torque = self.AVAIL_TORQUE[action]  # discrete action space
        torque = bound(action, -1.0, 1.0)  # continuous action space

        # Add noise to the force action
        if self.torque_noise_max > 0:
            torque += self.np_random.uniform(-self.torque_noise_max,
                                             self.torque_noise_max)

        self.state = self.dynamics(self.state, torque)
        terminal = self._terminal()
        # reward = -1.0 + terminal  # openAI cost function
        reward = self.reward_fn(self.state, action)
        # TODO: should this return self.state (dim 4) or self.observation (dim 6)?
        return self.state, reward, terminal, {}

    @property
    def observation(self):
        return jnp.array([
            cos(self.state[0]),
            sin(self.state[0]),
            cos(self.state[1]),
            sin(self.state[1]),
            self.state[2],
            self.state[3],
        ])

    def _terminal(self):
        return -cos(self.state[0]) - cos(self.state[1] + self.state[0]) > 1.0

    def _dsdt(self, augmented_state, t):
        m1 = self.LINK_MASS_1
        m2 = self.LINK_MASS_2
        l1 = self.LINK_LENGTH_1
        lc1 = self.LINK_COM_POS_1
        lc2 = self.LINK_COM_POS_2
        I1 = self.LINK_MOI
        I2 = self.LINK_MOI
        g = 9.8
        a = augmented_state[-1]
        s = augmented_state[:-1]
        theta1 = s[0]
        theta2 = s[1]
        dtheta1 = s[2]
        dtheta2 = s[3]
        d1 = m1 * lc1**2 + m2 * (l1**2 + lc2**2 + 2 * l1 * lc2 * cos(theta2)) + I1 + I2
        d2 = m2 * (lc2**2 + l1 * lc2 * cos(theta2)) + I2
        phi2 = m2 * lc2 * g * cos(theta1 + theta2 - pi / 2.0)
        phi1 = (-m2 * l1 * lc2 * dtheta2**2 * sin(theta2) -
                2 * m2 * l1 * lc2 * dtheta2 * dtheta1 * sin(theta2) +
                (m1 * lc1 + m2 * l1) * g * cos(theta1 - pi / 2) + phi2)
        if self.book_or_nips == "nips":
            # the following line is consistent with the description in the
            # paper
            ddtheta2 = (a + d2 / d1 * phi1 - phi2) / (m2 * lc2**2 + I2 - d2**2 / d1)
        else:
            # the following line is consistent with the java implementation and the
            # book
            ddtheta2 = (a + d2 / d1 * phi1 - m2 * l1 * lc2 * dtheta1**2 * sin(theta2) -
                        phi2) / (m2 * lc2**2 + I2 - d2**2 / d1)
        ddtheta1 = -(d2 * ddtheta2 + phi1) / d1
        return (dtheta1, dtheta2, ddtheta1, ddtheta2, 0.0)  # 4-state version
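# Usage sketch (assumption): query the dynamics Jacobians and cost gradients that
# Acrobot wires up with jax.jacfwd / jax.grad, as an iLQR-style planner would. It
# relies on the module-level rk4 / wrap / bound helpers referenced above;
# `acrobot_derivatives_demo` is an illustrative name.
def acrobot_derivatives_demo():
    env = Acrobot(seed=0)
    x = env.reset()                      # state (theta1, theta2, dtheta1, dtheta2)
    u = jnp.zeros((env.action_dim,))     # zero torque
    A = env.f_x(x, u)                    # d f / d x, shape (4, 4)
    B = env.f_u(x, u)                    # d f / d u, shape (4, action_dim)
    q = env.c_x(x, u)                    # d c / d x
    return A, B, q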
class LDS(Env):
    """Linear dynamical system variant with additive Gaussian state noise and a simple 2-D rendering."""

    def __init__(self,
                 state_size=1,
                 action_size=1,
                 A=None,
                 B=None,
                 C=None,
                 seed=0,
                 noise="normal"):
        if A is not None:
            assert (
                A.shape[0] == state_size and A.shape[1] == state_size
            ), "ERROR: Your input dynamics matrix does not have the correct shape."
        if B is not None:
            assert (
                B.shape[0] == state_size and B.shape[1] == action_size
            ), "ERROR: Your input dynamics matrix does not have the correct shape."
        self.viewer = None
        self.random = Random(seed)
        self.noise = noise
        self.state_size, self.action_size = state_size, action_size

        self.proj_vector1 = jax.random.normal(self.random.generate_key(),
                                              (state_size, ))
        self.proj_vector1 = self.proj_vector1 / jnp.linalg.norm(self.proj_vector1)
        self.proj_vector2 = jax.random.normal(self.random.generate_key(),
                                              (state_size, ))
        self.proj_vector2 = self.proj_vector2 / jnp.linalg.norm((self.proj_vector2))
        print('self.proj_vector1:' + str(self.proj_vector1))
        print('self.proj_vector2:' + str(self.proj_vector2))

        self.A = (A if A is not None else jax.random.normal(
            self.random.generate_key(), shape=(state_size, state_size)))
        self.B = (B if B is not None else jax.random.normal(
            self.random.generate_key(), shape=(state_size, action_size)))
        self.C = (C if C is not None else jax.numpy.identity(self.state_size))
        self.t = 0
        self.reset()

    def step(self, action):
        self.state = self.A @ self.state + self.B @ action
        if self.noise == "normal":
            self.state += np.random.normal(0, 0.1,
                                           size=(self.state_size, self.action_size))
        self.obs = self.C @ self.state

    @jax.jit
    def dynamics(self, state, action):
        new_state = self.A @ state + self.B @ action
        return new_state

    def reset(self):
        self.state = jax.random.normal(self.random.generate_key(),
                                       shape=(self.state_size, 1))
        self.obs = self.C @ self.state

    def render(self, mode="human"):
        if self.viewer is None:
            from gym.envs.classic_control import rendering

            self.viewer = rendering.Viewer(1000, 1000)
            self.viewer.set_bounds(-2.5, 2.5, -2.5, 2.5)
            fname = path.dirname(__file__) + "/classic/assets/lds_arrow.png"
            self.img = rendering.Image(fname, 0.35, 0.35)
            self.img.set_color(1.0, 1.0, 1.0)
            self.imgtrans = rendering.Transform()
            self.img.add_attr(self.imgtrans)
            fnamewind = path.dirname(__file__) + "/classic/assets/lds_grid.png"
            self.imgwind = rendering.Image(fnamewind, 5.0, 5.0)
            self.imgwind.set_color(0.5, 0.5, 0.5)
            self.imgtranswind = rendering.Transform()
            self.imgwind.add_attr(self.imgtranswind)

        self.viewer.add_onetime(self.imgwind)
        self.viewer.add_onetime(self.img)

        cur_x, cur_y = self.imgtrans.translation
        # new_x, new_y = self.state[0], self.state[1]
        new_x, new_y = jnp.dot(self.state.squeeze(), self.proj_vector1), jnp.dot(
            self.state.squeeze(), self.proj_vector2)
        diff_x = new_x - cur_x
        diff_y = new_y - cur_y
        new_rotation = jnp.arctan2(diff_y, diff_x)
        self.imgtrans.set_translation(new_x, new_y)
        self.imgtrans.set_rotation(new_rotation)

        return self.viewer.render(return_rgb_array=mode == "rgb_array")

    def close(self):
        if self.viewer:
            self.viewer.close()
            self.viewer = None
class Pendulum(Env):
    max_speed = 8.0
    max_torque = 2.0  # gym 2.
    high = np.array([1.0, 1.0, max_speed])
    action_space = spaces.Box(low=-max_torque,
                              high=max_torque,
                              shape=(1, ),
                              dtype=np.float32)
    observation_space = spaces.Box(low=-high, high=high, dtype=np.float32)
    metadata = {'render.modes': ['human', 'rgb_array']}

    def __init__(self, reward_fn=None, seed=0, horizon=50):
        # self.reward_fn = reward_fn or default_reward_fn
        self.dt = 0.05
        self.viewer = None
        self.state_size = 2
        self.action_size = 1
        self.action_dim = 1  # redundant with action_size but needed by ILQR
        self.H = horizon
        self.n, self.m = 2, 1
        self.angle_normalize = angle_normalize
        self.nsamples = 0
        self.last_u = None
        self.random = Random(seed)
        self.reset()

        # @jax.jit
        def _dynamics(state, action):
            self.nsamples += 1
            self.last_u = action
            th, thdot = state
            g = 10.0
            m = 1.0
            ell = 1.0
            dt = self.dt

            # Do not limit the control signals
            action = jnp.clip(action, -self.max_torque, self.max_torque)

            newthdot = (thdot + (-3 * g / (2 * ell) * jnp.sin(th + jnp.pi) + 3.0 /
                                 (m * ell**2) * action) * dt)
            newth = th + newthdot * dt
            newthdot = jnp.clip(newthdot, -self.max_speed, self.max_speed)

            return jnp.reshape(jnp.array([newth, newthdot]), (2, ))

        @jax.jit
        def c(x, u):
            # return np.sum(angle_normalize(x[0]) ** 2 + 0.1 * x[1] ** 2 + 0.001 * (u ** 2))
            return angle_normalize(x[0])**2 + .1 * (u[0]**2)

        self.reward_fn = reward_fn or c
        self.dynamics = _dynamics

        self.f, self.f_x, self.f_u = (
            _dynamics,
            jax.jacfwd(_dynamics, argnums=0),
            jax.jacfwd(_dynamics, argnums=1),
        )
        self.c, self.c_x, self.c_u, self.c_xx, self.c_uu = (
            c,
            jax.grad(c, argnums=0),
            jax.grad(c, argnums=1),
            jax.hessian(c, argnums=0),
            jax.hessian(c, argnums=1),
        )

    def reset(self):
        th = jax.random.uniform(self.random.generate_key(),
                                minval=-jnp.pi,
                                maxval=jnp.pi)
        thdot = jax.random.uniform(self.random.generate_key(),
                                   minval=-1.0,
                                   maxval=1.0)
        self.state = jnp.array([th, thdot])
        return self.state

    def render(self, mode="human"):
        if self.viewer is None:
            from gym.envs.classic_control import rendering

            self.viewer = rendering.Viewer(500, 500)
            self.viewer.set_bounds(-2.2, 2.2, -2.2, 2.2)
            rod = rendering.make_capsule(1, 0.2)
            rod.set_color(0.8, 0.3, 0.3)
            self.pole_transform = rendering.Transform()
            rod.add_attr(self.pole_transform)
            self.viewer.add_geom(rod)
            axle = rendering.make_circle(0.05)
            axle.set_color(0, 0, 0)
            self.viewer.add_geom(axle)
            fname = path.join(path.dirname(__file__), "assets/clockwise.png")
            self.img = rendering.Image(fname, 1.0, 1.0)
            self.imgtrans = rendering.Transform()
            self.img.add_attr(self.imgtrans)

        self.viewer.add_onetime(self.img)
        self.pole_transform.set_rotation(self.state[0] + np.pi / 2)
        if self.last_u:
            self.imgtrans.scale = (-self.last_u / 2, np.abs(self.last_u) / 2)

        return self.viewer.render(return_rgb_array=mode == "rgb_array")
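# Usage sketch (assumption): this Pendulum variant exposes dynamics and cost handles
# rather than a step() method, so a rollout calls env.dynamics directly and accumulates
# the cost c(x, u) = angle_normalize(x[0])**2 + 0.1 * u[0]**2 defined above;
# `rollout_pendulum` is an illustrative name.
def rollout_pendulum(T=50, seed=0):
    env = Pendulum(seed=seed)
    x, total_cost = env.reset(), 0.0
    for _ in range(T):
        u = jnp.zeros((1,))                  # zero torque; a planner would supply this
        total_cost += env.reward_fn(x, u)
        x = env.dynamics(x, u)
    return total_cost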
class DRC(Agent):
    def __init__(
        self,
        A: jnp.ndarray,
        B: jnp.ndarray,
        C: jnp.ndarray = None,
        K: jnp.ndarray = None,
        cost_fn: Callable[[jnp.ndarray, jnp.ndarray], Real] = None,
        m: int = 10,
        h: int = 50,
        lr_scale: Real = 0.03,
        decay: bool = True,
        RM: int = 1000,
        seed: int = 0,
    ) -> None:
        """
        Description: Initialize the dynamics of the model.

        Args:
            A (jnp.ndarray): system dynamics
            B (jnp.ndarray): system dynamics
            C (jnp.ndarray): system dynamics
            K (jnp.ndarray): Starting policy (optional). Defaults to the zero matrix.
            cost_fn (Callable[[jnp.ndarray, jnp.ndarray], Real]): cost of an
                (observation, action) pair. Defaults to quad_loss.
            m (positive int): history of the controller (number of M matrices)
            h (positive int): history of the system (truncation length of the
                Markov operator G)
            lr_scale (Real): base learning rate
            decay (boolean): whether to decay the learning rate as 1 / (t + 1)
            RM (positive int): norm bound on the controller parameters M
            seed (int): PRNG seed
        """
        cost_fn = cost_fn or quad_loss

        self.random = Random(seed)

        d_state, d_action = B.shape  # State & Action Dimensions
        C = jnp.identity(d_state) if C is None else C
        d_obs = C.shape[0]  # Observation Dimension

        self.t = 0  # Time Counter (for decaying learning rate)
        self.m, self.h = m, h
        self.lr_scale, self.decay = lr_scale, decay
        self.RM = RM

        # Construct truncated markov operator G
        self.G = jnp.zeros((h, d_obs, d_action))
        A_power = jnp.identity(d_state)
        for i in range(h):
            self.G = self.G.at[i].set(C @ A_power @ B)
            A_power = A_power @ A

        # Model Parameters
        # initial linear policy / perturbation contributions / bias
        self.K = K if K is not None else jnp.zeros((d_action, d_obs))
        self.M = lr_scale * jax.random.normal(
            self.random.generate_key(), shape=(m, d_action, d_obs)
        )

        # Past m nature y's such that y_nat[0] is the most recent
        self.y_nat = jnp.zeros((m, d_obs, 1))

        # Past h u's such that u[0] is the most recent
        self.us = jnp.zeros((h, d_action, 1))

        def policy_loss(M, G, y_nat, us):
            """Surrogate cost function"""

            def action(obs):
                """Action function"""
                return -self.K @ obs + jnp.tensordot(M, y_nat, axes=([0, 2], [0, 1]))

            final_state = y_nat[0] + jnp.tensordot(G, us, axes=([0, 2], [0, 1]))
            return cost_fn(final_state, action(final_state))

        self.policy_loss = policy_loss
        self.grad = jit(grad(policy_loss))

    def __call__(self, obs: jnp.ndarray) -> jnp.ndarray:
        """
        Description: Return the action based on current state and internal parameters.

        Args:
            obs (jnp.ndarray): current observation

        Returns:
            jnp.ndarray: action to take
        """
        # update y_nat
        self.update_noise(obs)

        # get action
        action = self.get_action(obs)

        # update parameters
        self.update_params(obs, action)

        return action

    def get_action(self, obs: jnp.ndarray) -> jnp.ndarray:
        """
        Description: get action from observation.

        Args:
            obs (jnp.ndarray):

        Returns:
            jnp.ndarray
        """
        return -self.K @ obs + jnp.tensordot(self.M, self.y_nat, axes=([0, 2], [0, 1]))

    def update(self, obs: jnp.ndarray, u: jnp.ndarray) -> None:
        self.update_noise(obs)
        self.update_params(obs, u)

    def update_noise(self, obs: jnp.ndarray) -> None:
        y_nat = obs - jnp.tensordot(self.G, self.us, axes=([0, 2], [0, 1]))
        self.y_nat = jnp.roll(self.y_nat, 1, axis=0)
        self.y_nat = self.y_nat.at[0].set(y_nat)

    def update_params(self, obs: jnp.ndarray, u: jnp.ndarray) -> None:
        """
        Description: update agent internal state.

        Args:
            obs (jnp.ndarray):
            u (jnp.ndarray):

        Returns:
            None
        """
        # update parameters
        delta_M = self.grad(self.M, self.G, self.y_nat, self.us)

        lr = self.lr_scale
        # lr *= (1/ (self.t+1)) if self.decay else 1
        lr = jax.lax.cond(self.decay,
                          lambda x: x * 1 / (self.t + 1),
                          lambda x: 1.0,
                          lr)
        self.M -= lr * delta_M

        # if(jnp.linalg.norm(self.M) > self.RM):
        #     self.M *= (self.RM / jnp.linalg.norm(self.M))
        self.M = jax.lax.cond(
            jnp.linalg.norm(self.M) > self.RM,
            lambda x: x * (self.RM / jnp.linalg.norm(self.M)),
            lambda x: x,
            self.M,
        )

        # update us
        self.us = jnp.roll(self.us, 1, axis=0)
        self.us = self.us.at[0].set(u)

        self.t += 1
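# Usage sketch (assumption): run the DRC controller in closed loop on the LDS
# environment defined earlier, feeding back the observation env.obs each step. It
# assumes the module-level quad_loss default cost; `run_drc` and `T` are illustrative.
def run_drc(T=100, state_size=2, action_size=1, seed=0):
    env = LDS(state_size=state_size, action_size=action_size, seed=seed)
    agent = DRC(A=env.A, B=env.B, C=env.C, seed=seed)
    for _ in range(T):
        action = agent(env.obs)   # updates y_nat, computes u, and adapts M via a gradient step
        env.step(action)
    return env, agent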
class MountainCar(Env):
    """Continuous-action mountain car variant that exposes dynamics/cost derivatives for planners."""

    metadata = {'render.modes': ['human', 'rgb_array']}

    def __init__(self, goal_velocity=0, seed=0, horizon=50):
        self.viewer = None
        self.min_action = -1.0
        self.max_action = 1.0
        self.min_position = -1.2
        self.max_position = 0.6
        self.max_speed = 0.07
        self.goal_position = 0.45  # was 0.5 in gym, 0.45 in Arnaud de Broissia's version
        self.goal_velocity = goal_velocity
        self.power = 0.0015
        self.H = horizon
        self.action_dim = 1
        self.random = Random(seed)

        self.low_state = np.array([self.min_position, -self.max_speed], dtype=np.float32)
        self.high_state = np.array([self.max_position, self.max_speed], dtype=np.float32)

        self.action_space = spaces.Box(
            low=self.min_action, high=self.max_action, shape=(1,), dtype=np.float32
        )
        self.observation_space = spaces.Box(
            low=self.low_state, high=self.high_state, dtype=np.float32
        )

        self.nsamples = 0

        # @jax.jit
        def _dynamics(state, action):
            self.nsamples += 1
            position = state[0]
            velocity = state[1]
            force = jnp.minimum(jnp.maximum(action, self.min_action), self.max_action)

            velocity += force * self.power - 0.0025 * jnp.cos(3 * position)
            velocity = jnp.clip(velocity, -self.max_speed, self.max_speed)
            position += velocity
            position = jnp.clip(position, self.min_position, self.max_position)

            reset_velocity = (position == self.min_position) & (velocity < 0)
            # print('state.shape = ' + str(state.shape))
            # print('position.shape = ' + str(position.shape))
            # print('velocity.shape = ' + str(velocity.shape))
            # print('reset_velocity.shape = ' + str(reset_velocity.shape))
            velocity = jax.lax.cond(reset_velocity[0],
                                    velocity, lambda x: jnp.zeros((1,)),
                                    velocity, lambda x: x)
            # print('velocity.shape AFTER = ' + str(velocity.shape))
            return jnp.reshape(jnp.array([position, velocity]), (2,))

        @jax.jit
        def c(x, u):
            # use the cost arguments (not self.state) so jax.grad/hessian w.r.t. x are meaningful
            position, velocity = x[0], x[1]
            done = (position >= self.goal_position) & (velocity >= self.goal_velocity)
            return -100.0 * done + 0.1 * (u[0] + 1) ** 2

        self.reward_fn = c
        self.dynamics = _dynamics

        self.f, self.f_x, self.f_u = (
            _dynamics,
            jax.jacfwd(_dynamics, argnums=0),
            jax.jacfwd(_dynamics, argnums=1),
        )
        self.c, self.c_x, self.c_u, self.c_xx, self.c_uu = (
            c,
            jax.grad(c, argnums=0),
            jax.grad(c, argnums=1),
            jax.hessian(c, argnums=0),
            jax.hessian(c, argnums=1),
        )

        self.reset()

    def step(self, action):
        self.state = self.dynamics(self.state, action)
        position = self.state[0]
        velocity = self.state[1]

        # Convert a possible numpy bool to a Python bool.
        done = (position >= self.goal_position) & (velocity >= self.goal_velocity)

        # reward = 100.0 * done
        # reward -= jnp.power(action, 2) * 0.1
        reward = self.reward_fn(self.state, action)

        return self.state, reward, done, {}

    def reset(self):
        self.state = jnp.array(
            [jax.random.uniform(self.random.generate_key(), minval=-0.6, maxval=0.4), 0]
        )
        return self.state

    def _height(self, xs):
        return jnp.sin(3 * xs) * 0.45 + 0.55

    def render(self, mode="human"):
        screen_width = 600
        screen_height = 400

        world_width = self.max_position - self.min_position
        scale = screen_width / world_width
        carwidth = 40
        carheight = 20

        if self.viewer is None:
            from gym.envs.classic_control import rendering

            self.viewer = rendering.Viewer(screen_width, screen_height)
            xs = np.linspace(self.min_position, self.max_position, 100)
            ys = self._height(xs)
            xys = list(zip((xs - self.min_position) * scale, ys * scale))

            self.track = rendering.make_polyline(xys)
            self.track.set_linewidth(4)
            self.viewer.add_geom(self.track)

            clearance = 10

            l, r, t, b = -carwidth / 2, carwidth / 2, carheight, 0
            car = rendering.FilledPolygon([(l, b), (l, t), (r, t), (r, b)])
            car.add_attr(rendering.Transform(translation=(0, clearance)))
            self.cartrans = rendering.Transform()
            car.add_attr(self.cartrans)
            self.viewer.add_geom(car)
            frontwheel = rendering.make_circle(carheight / 2.5)
            frontwheel.set_color(0.5, 0.5, 0.5)
            frontwheel.add_attr(rendering.Transform(translation=(carwidth / 4, clearance)))
            frontwheel.add_attr(self.cartrans)
            self.viewer.add_geom(frontwheel)
            backwheel = rendering.make_circle(carheight / 2.5)
            backwheel.add_attr(rendering.Transform(translation=(-carwidth / 4, clearance)))
            backwheel.add_attr(self.cartrans)
            backwheel.set_color(0.5, 0.5, 0.5)
            self.viewer.add_geom(backwheel)
            flagx = (self.goal_position - self.min_position) * scale
            flagy1 = self._height(self.goal_position) * scale
            flagy2 = flagy1 + 50
            flagpole = rendering.Line((flagx, flagy1), (flagx, flagy2))
            self.viewer.add_geom(flagpole)
            flag = rendering.FilledPolygon(
                [(flagx, flagy2), (flagx, flagy2 - 10), (flagx + 25, flagy2 - 5)]
            )
            flag.set_color(0.8, 0.8, 0)
            self.viewer.add_geom(flag)

        pos = self.state[0]
        self.cartrans.set_translation((pos - self.min_position) * scale,
                                      self._height(pos) * scale)
        self.cartrans.set_rotation(math.cos(3 * pos))

        return self.viewer.render(return_rgb_array=mode == "rgb_array")
class Pendulum(Env):
    """Classic pendulum swing-up with continuous torque (simple variant without a step() method)."""

    max_speed = 8.0
    max_torque = 2.0  # gym 2.
    high = np.array([1.0, 1.0, max_speed])
    action_space = gym.spaces.Box(low=-max_torque, high=max_torque, shape=(1,), dtype=np.float32)
    observation_space = gym.spaces.Box(low=-high, high=high, dtype=np.float32)

    def __init__(self, reward_fn=None, seed=0):
        self.reward_fn = reward_fn or default_reward_fn
        self.dt = 0.05
        self.viewer = None
        self.state_size = 2
        self.action_size = 1
        self.n, self.m = 2, 1
        self.angle_normalize = angle_normalize
        self.last_u = None  # set so render() works before any action has been applied
        self.random = Random(seed)
        self.reset()

    @jax.jit
    def dynamics(self, state, action):
        th, thdot = state
        g = 10.0
        m = 1.0
        ell = 1.0
        dt = self.dt

        # Do not limit the control signals
        action = jnp.clip(action, -self.max_torque, self.max_torque)

        newthdot = (
            thdot
            + (-3 * g / (2 * ell) * jnp.sin(th + jnp.pi) + 3.0 / (m * ell ** 2) * action) * dt
        )
        newth = th + newthdot * dt
        newthdot = jnp.clip(newthdot, -self.max_speed, self.max_speed)

        return jnp.array([newth, newthdot])

    def reset(self):
        th = jax.random.uniform(self.random.generate_key(), minval=-jnp.pi, maxval=jnp.pi)
        thdot = jax.random.uniform(self.random.generate_key(), minval=-1.0, maxval=1.0)
        self.state = jnp.array([th, thdot])
        return self.state

    def render(self, mode="human"):
        if self.viewer is None:
            from gym.envs.classic_control import rendering

            self.viewer = rendering.Viewer(500, 500)
            self.viewer.set_bounds(-2.2, 2.2, -2.2, 2.2)
            rod = rendering.make_capsule(1, 0.2)
            rod.set_color(0.8, 0.3, 0.3)
            self.pole_transform = rendering.Transform()
            rod.add_attr(self.pole_transform)
            self.viewer.add_geom(rod)
            axle = rendering.make_circle(0.05)
            axle.set_color(0, 0, 0)
            self.viewer.add_geom(axle)
            fname = path.join(path.dirname(__file__), "assets/clockwise.png")
            self.img = rendering.Image(fname, 1.0, 1.0)
            self.imgtrans = rendering.Transform()
            self.img.add_attr(self.imgtrans)

        self.viewer.add_onetime(self.img)
        self.pole_transform.set_rotation(self.state[0] + np.pi / 2)
        if self.last_u:
            self.imgtrans.scale = (-self.last_u / 2, np.abs(self.last_u) / 2)

        return self.viewer.render(return_rgb_array=mode == "rgb_array")