class ReacherState(Obj):
    """State container for the Reacher environment.

    Attributes:
        arr: flat state vector (Reacher stores [th1, th2, dth1, dth2, Dx, Dy]).
        h: step counter carried alongside the array.
    """

    arr: jnp.ndarray = field(jaxed=True)
    h: float = field(0.0, jaxed=True)

    def flatten(self):
        """Return the flat array representation of this state."""
        flat = self.arr
        return flat

    # TODO(dsuo): this should be a classmethod.
    def unflatten(self, arr):
        """Build a new ReacherState holding `arr` and this state's `h`.

        Args:
            arr: replacement state array.

        Returns:
            A new ReacherState; `self` is not modified.
        """
        step_count = self.h
        return ReacherState(arr=arr, h=step_count)
class MountainCar(Env):
    """MountainCar: under-powered car that must climb a 1-D hill.

    The state is the length-2 array ``[position, velocity]``. Both ``init`` and
    ``__call__`` return that array as the observation.

    Attributes:
        key: PRNG key used to sample the initial position. ``setup`` defaults it
            to ``jax.random.PRNGKey(0)`` when not provided.
        goal_velocity, goal_position: goal description (not used by the
            dynamics below).
        min_action, max_action: bounds the action is clipped to.
        min_position, max_position: bounds the position is clipped to.
        max_speed: bound on |velocity|.
        power: scale applied to the clipped action.
        low_state, high_state: state bounds computed in ``setup``.
    """

    key: jnp.ndarray = field(jaxed=False)
    goal_velocity: float = field(0.0, jaxed=False)
    min_action: float = field(-1.0, jaxed=False)
    max_action: float = field(1.0, jaxed=False)
    min_position: float = field(-1.2, jaxed=False)
    max_position: float = field(0.6, jaxed=False)
    max_speed: float = field(0.07, jaxed=False)
    goal_position: float = field(0.5, jaxed=False)
    power = 0.0015
    low_state: jnp.ndarray = field(jaxed=False)
    high_state: jnp.ndarray = field(jaxed=False)

    def setup(self):
        """Compute state bounds and default the PRNG key."""
        self.low_state = jnp.array([self.min_position, -self.max_speed])
        self.high_state = jnp.array([self.max_position, self.max_speed])
        if self.key is None:
            self.key = jax.random.PRNGKey(0)

    def init(self):
        """Sample the initial state.

        Returns:
            ``(state, obs)``. Both are ``[position, 0]``, with position drawn
            uniformly from [-0.6, -0.4), the starting range of the reference
            MountainCar environment.
        """
        # `minval` is the correct keyword for jax.random.uniform (the original
        # `min_val` raised TypeError). The upper bound is -0.4 (not +0.4) to
        # match the reference environment's starting range.
        position = jax.random.uniform(self.key, minval=-0.6, maxval=-0.4)
        state = jnp.array([position, 0])
        return state, state

    def __call__(self, state, action):
        """Advance one step.

        Args:
            state: array ``[position, velocity]``.
            action: force command, clipped to ``[min_action, max_action]``.

        Returns:
            ``(new_state, obs)``, both the length-2 array ``[position, velocity]``.
        """
        position, velocity = state
        force = jnp.clip(action, self.min_action, self.max_action)
        velocity = velocity + force * self.power - 0.0025 * jnp.cos(3 * position)
        velocity = jnp.clip(velocity, -self.max_speed, self.max_speed)
        position = position + velocity
        position = jnp.clip(position, self.min_position, self.max_position)
        # Hitting the left wall while still moving left zeroes the velocity.
        # jnp.where replaces the deprecated 5-argument lax.cond. It works for
        # both scalar and shape-(1,) actions; the original indexed [0] and
        # failed on scalars.
        reset_velocity = (position == self.min_position) & (velocity < 0)
        velocity = jnp.where(reset_velocity, 0.0, velocity)
        new_state = jnp.reshape(jnp.array([position, velocity]), (2,))
        return new_state, new_state
class Random(Agent):
    """Random: agent whose action is ``func(key)`` for a fresh PRNG key each step.

    Attributes:
        seed: seed for the initial PRNG key (the agent state).
        func: maps a PRNG key to an action; the default always returns 0.0.
    """

    seed: int = field(0, jaxed=False)
    func: Callable = field(lambda key: 0.0, jaxed=False)

    def init(self):
        """Return the initial agent state: a PRNG key built from `seed`."""
        return jax.random.PRNGKey(self.seed)

    def __call__(self, state, obs):
        """Split the carried key; sample an action with the first half.

        The second half becomes the next agent state. `obs` is ignored.
        """
        halves = jax.random.split(state)
        sample_key, next_state = halves[0], halves[1]
        return next_state, self.func(sample_key)
class BangBang(Agent):
    """BangBang: two-level (on/off) threshold controller.

    Attributes:
        target: threshold the observation is compared against.
        min_action: action emitted when ``obs > target``.
        max_action: action emitted otherwise (including when obs is NaN).
    """

    target: float = field(0.0, jaxed=False)
    min_action: float = field(0.0, jaxed=False)
    max_action: float = field(100.0, jaxed=False)

    def init(self):
        """Stateless agent: the agent state is ``None``."""
        return None

    def __call__(self, state, obs):
        """Assume observation is env state.

        Args:
            state: agent state (passed through unchanged).
            obs: observation compared against ``target``.

        Returns:
            ``(state, action)``, where action is ``min_action`` where
            ``obs > target`` and ``max_action`` elsewhere.
        """
        # jnp.where instead of a Python conditional. Evaluating `obs > target`
        # in Python fails when obs is a tracer (jit/vmap/grad); jnp.where also
        # applies elementwise to array observations.
        action = jnp.where(obs > self.target, self.min_action, self.max_action)
        return state, action
class PlanarQuadrotorState(Obj):
    """PlanarQuadrotorState: state container for the PlanarQuadrotor env.

    Attributes:
        arr: flat state vector.
        h: step counter.
        last_action: most recently applied action.
    """

    arr: jnp.ndarray = field(jaxed=True)
    h: float = field(0.0, jaxed=True)
    last_action: jnp.ndarray = field(jaxed=True)

    def flatten(self):
        """Return the flat state array."""
        flat = self.arr
        return flat

    def unflatten(self, arr):
        """Return a new state holding `arr`, keeping `h` and `last_action`."""
        kept_h, kept_action = self.h, self.last_action
        return PlanarQuadrotorState(arr=arr, h=kept_h, last_action=kept_action)
class Zero(Agent):
    """Zero: agent that always outputs an all-zeros action.

    Attributes:
        action_dim: shape argument passed to ``jnp.zeros``.
    """

    action_dim: int = field(1, jaxed=False)

    def init(self):
        """Stateless agent: the agent state is ``None``."""
        return None

    def __call__(self, state, obs):
        """Ignore `obs`; return the unchanged state and a zero action."""
        zero_action = jnp.zeros(self.action_dim)
        return state, zero_action
class LDS(Env):
    """LDS: discrete-time linear dynamical system.

    Dynamics ``x_{t+1} = A x_t + B u_t`` with observation ``y_t = C x_t``.

    Attributes:
        A, B, C: system, input and observation matrices.
        key: PRNG key used to sample the initial state.
        state_size: number of rows of the initial state column vector.
        action_size: action dimension (not read by the methods below).
    """

    # Annotated as jnp.ndarray: `jnp.array` is a constructor function, not a
    # type, and `key` holds a PRNG key array, not an int.
    A: jnp.ndarray = field(jaxed=False)
    B: jnp.ndarray = field(jaxed=False)
    C: jnp.ndarray = field(jaxed=False)
    key: jnp.ndarray = field(jax.random.PRNGKey(0), jaxed=False)
    state_size: int = field(1, jaxed=False)
    action_size: int = field(1, jaxed=False)

    def init(self):
        """Sample the initial state.

        Returns:
            ``(state, obs)``. ``state`` is a standard-normal column vector of
            shape ``(state_size, 1)`` and ``obs = C @ state``.
        """
        state = jax.random.normal(self.key, shape=(self.state_size, 1))
        # Observe through C exactly as __call__ does, so the first observation
        # has the same meaning and shape as every later one.
        return state, self.C @ state

    def __call__(self, state, action):
        """Advance one step.

        Args:
            state: current state column vector.
            action: input vector multiplied by ``B``.

        Returns:
            ``(new_state, obs)`` with ``new_state = A @ state + B @ action`` and
            ``obs = C @ new_state``.
        """
        new_state = self.A @ state + self.B @ action
        return new_state, self.C @ new_state
class Cartpole(Env):
    """Cartpole: cart-pole with dynamics linear in the state and action.

    One step computes ``arr' = A @ arr + B @ (action + state.offset)``. From
    the structure of ``A``, indices 2 and 3 are the rates of indices 0 and 1 —
    presumably ``[x, theta, x_dot, theta_dot]`` (TODO confirm).

    Attributes:
        m, M: pole and cart mass.
        l: pole length.
        g: gravitational acceleration.
        dt: time step.
        H: horizon (not read by the methods below).
        goal_state: defaults to zeros in ``setup``.
        dynamics: flag (not read by the methods below).
    """

    m: float = field(0.1, jaxed=False)
    M: float = field(1.0, jaxed=False)
    l: float = field(1.0, jaxed=False)
    g: float = field(9.81, jaxed=False)
    dt: float = field(0.02, jaxed=False)
    H: int = field(10, jaxed=False)
    goal_state: jnp.ndarray = field(jaxed=False)
    dynamics: bool = field(False, jaxed=False)

    def setup(self):
        """Default the goal state to the origin."""
        if self.goal_state is None:
            self.goal_state = jnp.array([0.0, 0.0, 0.0, 0.0])

    def init(self):
        """Return ``(state, obs)`` at the origin.

        ``obs`` is the state array, matching what ``__call__`` returns.
        """
        state = CartpoleState(arr=jnp.array([0.0, 0.0, 0.0, 0.0]))
        return state, state.arr

    def __call__(self, state, action):
        """Advance one step.

        Args:
            state: CartpoleState; its ``offset`` is added to the action.
            action: presumably a length-1 array, since ``B`` is (4, 1) — TODO confirm.

        Returns:
            ``(new_state, arr)``: state with updated ``arr`` and ``h + 1``, and
            the new state array.
        """
        dt, m, M, l, g = self.dt, self.m, self.M, self.l, self.g
        A = jnp.array([
            [1.0, 0.0, dt, 0.0],
            [0.0, 1.0, 0.0, dt],
            [0.0, dt * m * g / M, 1.0, 0.0],
            [0.0, dt * (m + M) * g / (M * l), 0.0, 1.0],
        ])
        # No trailing commas: they previously turned B and arr into 1-tuples,
        # which made `B @ ...` raise and `arr` a tuple instead of an array.
        B = jnp.array([[0.0], [0.0], [dt / M], [dt / (M * l)]])
        arr = A @ state.arr + B @ (action + state.offset)
        return state.replace(arr=arr, h=state.h + 1), arr
class Pendulum(Env):
    """Pendulum: torque-driven pendulum.

    State ``arr`` is ``[sin(th), cos(th), thdot]``. One step applies a
    semi-implicit Euler update: velocity first, then angle using the new
    velocity.

    Attributes:
        m, l: mass and length.
        g: gravitational acceleration.
        max_torque: the action is squashed with ``tanh`` and scaled by this.
        dt: time step.
        H: horizon (not read by the methods below).
        goal_state: defaults to ``[0, -1, 0]`` in ``setup``.
    """

    m: float = field(1.0, jaxed=False)
    l: float = field(1.0, jaxed=False)
    g: float = field(9.81, jaxed=False)
    max_torque: float = field(1.0, jaxed=False)
    dt: float = field(0.02, jaxed=False)
    H: int = field(300, jaxed=False)
    goal_state: jnp.ndarray = field(jaxed=False)

    def init(self):
        """Return ``(state, obs)`` with th = 0 and zero velocity.

        ``obs`` is the state array, matching what ``__call__`` returns.
        """
        state = PendulumState(arr=jnp.array([0.0, 1.0, 0.0]))
        return state, state.arr

    def setup(self):
        """Default the goal state to ``[0, -1, 0]`` (th = pi, at rest)."""
        if self.goal_state is None:
            self.goal_state = jnp.array([0., -1., 0.])

    def __call__(self, state, action):
        """Advance one step.

        Args:
            state: PendulumState with ``arr = [sin(th), cos(th), thdot]``.
            action: array whose first element is the raw torque command.

        Returns:
            ``(new_state, arr)`` with ``h`` incremented.
        """
        sin, cos, thdot = state.arr
        th = jnp.arctan2(sin, cos)
        torque = self.max_torque * jnp.tanh(action[0])
        thddot = (-3.0 * self.g / (2.0 * self.l) * jnp.sin(th + jnp.pi) +
                  3.0 / (self.m * self.l**2) * torque)
        # The new velocity integrates the acceleration onto the previous
        # velocity. The original added the acceleration to the angle and
        # dropped `* dt`.
        newthdot = thdot + thddot * self.dt
        newth = th + newthdot * self.dt
        arr = jnp.array([jnp.sin(newth), jnp.cos(newth), newthdot])
        return PendulumState(arr=arr, h=state.h + 1), arr
class PID(Agent):
    """PID: three-term controller acting on an error observation.

    The I and D channels are first-order low-pass filtered with coefficient
    ``decay = dt / (dt + RC)``, computed in ``setup``.

    Attributes:
        K_P, K_I, K_D: gains for the P, I and D terms.
        RC: filter time constant.
        dt: controller time step.
        decay: filter coefficient (derived; set in ``setup``).
    """

    K_P: float = field(0.0, jaxed=True)
    K_I: float = field(0.0, jaxed=True)
    K_D: float = field(0.0, jaxed=True)
    RC: float = field(0.5, jaxed=False)
    dt: float = field(0.03, jaxed=False)
    decay: float = field(jaxed=False)

    def init(self):
        """Return a PIDState with its default (zero) terms."""
        return PIDState()

    def setup(self):
        """Derive the low-pass coefficient from `dt` and `RC`."""
        self.decay = self.dt / (self.dt + self.RC)

    def __call__(self, state, obs):
        """Compute the control action; the observation is the error.

        Returns:
            ``(new_state, action)``. ``new_state`` stores the current error as
            ``P`` so the next call can difference against it.
        """
        alpha = self.decay
        err = obs
        # NOTE(review): despite its name, I is an exponential moving average of
        # the error, not a running sum.
        filtered_i = state.I + alpha * (err - state.I)
        # D low-pass filters the step-to-step error difference (no /dt).
        filtered_d = state.D + alpha * (err - state.P - state.D)
        control = self.K_P * err + self.K_I * filtered_i + self.K_D * filtered_d
        return state.replace(P=err, I=filtered_i, D=filtered_d), control
class BraxEnv(Env):
    """Brax: wraps a brax ``System`` (optionally a brax env) as an Env.

    With only ``sys``, the state is the system's qp and the observation is the
    qp itself. With ``env``, the env's ``sys`` is used for stepping and the
    env's ``_get_obs`` produces observations.
    """

    sys: brax.System = field(jaxed=False)
    env: brax_envs.Env = field(jaxed=False)

    @classmethod
    def from_env(cls, env):
        """Build from an existing brax env, reusing its ``sys``."""
        return cls.create(env=env, sys=env.sys)

    @classmethod
    def from_name(cls, name):
        """Build from a brax environment name via ``brax_envs.create``."""
        return cls.from_env(brax_envs.create(env_name=name))

    @classmethod
    def from_config(cls, config):
        """Build a raw-system wrapper (no env) from a brax config."""
        return cls.create(sys=brax.System(config))

    def setup(self):
        """Validate that a physics system is available.

        Raises:
            ValueError: if ``sys`` is None.
        """
        if self.sys is None:
            raise ValueError("BraxEnv requires `sys` or `env` field specified.")

    def init(self, rng=None):
        """Return the initial ``(qp, obs)``.

        Args:
            rng: PRNG key forwarded to ``env.reset``; ignored when ``env`` is
                None. NOTE(review): with an env, passing None is forwarded
                as-is — confirm brax accepts it.

        Returns:
            Without env: ``(default_qp, default_qp)``. With env: the reset
            state's ``(qp, obs)``.
        """
        if self.env is None:
            # The previous `obs = state if self.env is None else ...` could only
            # ever take its first arm inside this branch.
            state = self.sys.default_qp()
            return state, state
        state = self.env.reset(rng=rng)
        return state.qp, state.obs

    def reset(self, rng=None):
        """Alias for ``init``."""
        return self.init(rng=rng)

    def __call__(self, state, action):
        """Step the physics system.

        Args:
            state: current qp.
            action: action forwarded to ``sys.step``.

        Returns:
            ``(qp, obs)``. ``obs`` is the qp without an env, else the env's
            ``_get_obs(qp, info)``.
        """
        state, info = self.sys.step(state, action)
        obs = state if self.env is None else self.env._get_obs(state, info)
        return state, obs

    def render(self, states):
        """Render a sequence of qp states to an HTML object."""
        return HTML(html.render(self.sys, states))
class CartpoleState(Obj):
    """CartpoleState: state container for the Cartpole environment.

    Field order matters: it defines the constructor's argument order.
    """
    # Flat state vector updated by Cartpole's linear dynamics.
    arr: jnp.ndarray = field(jaxed=True)
    # Step counter; Cartpole.__call__ increments it by 1 each step.
    h: int = field(0, jaxed=True)
    # Added to the action before it is multiplied by B in Cartpole.__call__.
    offset: float = field(0.0, jaxed=False)
class PlanarQuadrotor(Env):
    """PlanarQuadrotor: 2-D quadrotor with two thrusters in a wind field.

    State ``arr`` is ``[x, y, th, xdot, ydot, thdot]`` and the action is the
    thrust pair ``(u1, u2)``. One step is an explicit Euler update of size
    ``dt``.

    Attributes:
        m, l, g: mass, arm length, gravitational acceleration.
        dt: time step.
        H: horizon (not read by the methods below).
        wind: wind magnitude forwarded to ``wind_func``.
        wind_func: ``(x, y, wind) -> force``; indices 0/1 of the result are
            used as the x/y force.
        goal_state: defaults to zeros in ``setup``.
        goal_action: recomputed in ``setup`` as equal hover thrust ``m*g/2``.
        state_dim, action_dim: state and action dimensions.
    """

    m: float = field(0.1, jaxed=False)
    l: float = field(0.2, jaxed=False)
    g: float = field(9.81, jaxed=False)
    dt: float = field(0.05, jaxed=False)
    H: int = field(100, jaxed=False)
    wind: float = field(0.0, jaxed=False)
    wind_func: Callable[[float, float, float],
                        List[float]] = field(dissipative, jaxed=False)
    goal_state: jnp.ndarray = field(jaxed=False)
    goal_action: jnp.ndarray = field(jaxed=False)
    # Dimensions are counts, so they are annotated int (previously float).
    state_dim: int = field(6, jaxed=False)
    action_dim: int = field(2, jaxed=False)

    def setup(self):
        """Set the hover goal action and default the goal state."""
        # Always recomputed (not guarded by `is None`) so it tracks m and g.
        self.goal_action = jnp.array(
            [self.m * self.g / 2.0, self.m * self.g / 2.0])
        if self.goal_state is None:
            self.goal_state = jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0])

    def _wind_field(self, x, y):
        """Return the wind force at position ``(x, y)`` via ``wind_func``."""
        return self.wind_func(x, y, self.wind)

    def init(self):
        """Return ``(state, obs)`` at position (1, 1), at rest, no last action."""
        new_state = PlanarQuadrotorState(
            arr=jnp.array([1., 1., 0., 0., 0., 0.]),
            last_action=jnp.array([0., 0.]))
        return new_state, new_state.arr

    def __call__(self, state, action):
        """Advance one explicit-Euler step.

        Returns:
            ``(new_state, arr)``. ``new_state`` records ``last_action`` and
            increments ``h``.
        """
        x, y, th, xdot, ydot, thdot = state.arr
        u1, u2 = action
        m, g, l, dt = self.m, self.g, self.l, self.dt
        wind = self._wind_field(x, y)
        xddot = -(u1 + u2) * jnp.sin(th) / m + wind[0] / m
        yddot = (u1 + u2) * jnp.cos(th) / m - g + wind[1] / m
        thddot = l * (u2 - u1) / (m * l**2)
        state_dot = jnp.array([xdot, ydot, thdot, xddot, yddot, thddot])
        arr = state.arr + state_dot * dt
        return state.replace(arr=arr, last_action=action, h=state.h + 1), arr
class Reacher(Env):
    """Reacher: planar two-link arm driven by joint torques toward a goal.

    State ``arr`` is ``[th1, th2, dth1, dth2, Dx, Dy]``, where ``(Dx, Dy)`` is
    the end-effector position minus ``goal_coord``. One step is an explicit
    Euler update of size ``dt``.

    Attributes:
        m1, m2, l1, l2: link masses and lengths.
        g: gravitational acceleration (0 by default).
        max_torque: not read by the methods below.
        dt: time step.
        H: horizon (not read by the methods below).
        goal_coord: target point; defaults to ``(0, 1.8)`` in ``setup``.
        state_dim, action_dim: state and action dimensions.
    """

    m1: float = field(1.0, jaxed=False)
    m2: float = field(1.0, jaxed=False)
    l1: float = field(1.0, jaxed=False)
    l2: float = field(1.0, jaxed=False)
    g: float = field(0.0, jaxed=False)
    max_torque: float = field(1.0, jaxed=False)
    dt: float = field(0.01, jaxed=False)
    H: int = field(200, jaxed=False)
    goal_coord: jnp.ndarray = field(jaxed=False)
    # Dimensions are counts, so they are annotated int (previously float).
    state_dim: int = field(6, jaxed=False)
    action_dim: int = field(2, jaxed=False)

    def _goal_offset(self, th1, th2):
        """Return ``(Dx, Dy)``: end-effector position minus ``goal_coord``."""
        x = self.l1 * jnp.cos(th1) + self.l2 * jnp.cos(th1 + th2)
        y = self.l1 * jnp.sin(th1) + self.l2 * jnp.sin(th1 + th2)
        return x - self.goal_coord[0], y - self.goal_coord[1]

    def init(self):
        """Return ``(state, obs)`` with joints at (pi/4, pi/2), at rest.

        ``obs`` is the state array, matching what ``__call__`` returns.
        """
        th1, th2 = jnp.pi / 4, jnp.pi / 2
        dx, dy = self._goal_offset(th1, th2)
        state = ReacherState(arr=jnp.array([th1, th2, 0.0, 0.0, dx, dy]))
        return state, state.arr

    def setup(self):
        """Default the goal point to ``(0, 1.8)``."""
        if self.goal_coord is None:
            self.goal_coord = jnp.array([0., 1.8])

    def __call__(self, state, action):
        """Advance one explicit-Euler step.

        Args:
            state: ReacherState with ``arr = [th1, th2, dth1, dth2, Dx, Dy]``.
            action: joint torques ``(t1, t2)``.

        Returns:
            ``(new_state, arr)`` with ``h`` incremented.
        """
        m1, m2, l1, l2, g = self.m1, self.m2, self.l1, self.l2, self.g
        th1, th2, dth1, dth2, _, _ = state.arr
        t1, t2 = action
        # Symmetric 2x2 matrix A(th2) and right-hand side b (torques,
        # velocity-dependent terms, gravity terms); solve A @ ddth = b.
        a11 = (m1 + m2) * l1**2 + m2 * l2**2 + 2 * m2 * l1 * l2 * jnp.cos(th2)
        a12 = m2 * l2**2 + m2 * l1 * l2 * jnp.cos(th2)
        a22 = m2 * l2**2
        b1 = (t1 + m2 * l1 * l2 * (2 * dth1 + dth2) * dth2 * jnp.sin(th2) -
              m2 * l2 * g * jnp.sin(th1 + th2) -
              (m1 + m2) * l1 * g * jnp.sin(th1))
        b2 = (t2 - m2 * l1 * l2 * dth1**2 * jnp.sin(th2) -
              m2 * l2 * g * jnp.sin(th1 + th2))
        A, b = jnp.array([[a11, a12], [a12, a22]]), jnp.array([b1, b2])
        # solve() is more accurate and cheaper than forming inv(A) explicitly.
        ddth1, ddth2 = jnp.linalg.solve(A, b)
        # Angles advance with the old velocities, velocities with the new
        # accelerations (explicit Euler).
        th1, th2 = th1 + dth1 * self.dt, th2 + dth2 * self.dt
        dth1, dth2 = dth1 + ddth1 * self.dt, dth2 + ddth2 * self.dt
        Dx, Dy = self._goal_offset(th1, th2)
        arr = jnp.array([th1, th2, dth1, dth2, Dx, Dy])
        new_state = state.replace(arr=arr, h=state.h + 1)
        return new_state, arr
class Acrobot(Env):
    """Acrobot: two-link pendulum; the action enters only the theta2 equation.

    The internal state is the 4-vector ``[theta1, theta2, dtheta1, dtheta2]``.
    ``__call__`` returns that state and the observation
    ``[cos th1, sin th1, cos th2, sin th2, dth1, dth2]``. One step integrates
    ``_dsdt`` over ``[0, dt]`` with ``rk4`` (defined elsewhere), holding the
    action constant during the step.

    NOTE(review): ``init`` is not implemented and currently returns None.
    NOTE(review): ``LINK_LENGTH_2``, ``AVAIL_TORQUE``, ``torque_noise_max``,
    ``high`` and ``low`` are not read by the dynamics in this class.
    NOTE(review): ``LINK_MOI`` is the only parameter declared without
    ``jaxed=False`` — confirm that is intentional.
    """
    key: jnp.ndarray = field(jaxed=False)
    dt: float = field(0.2, jaxed=False)
    LINK_LENGTH_1: float = field(1.0, jaxed=False)
    LINK_LENGTH_2: float = field(1.0, jaxed=False)
    LINK_MASS_1: float = field(1.0, jaxed=False)
    LINK_MASS_2: float = field(1.0, jaxed=False)
    LINK_COM_POS_1: float = field(0.0, jaxed=False)
    LINK_COM_POS_2: float = field(0.0, jaxed=False)
    LINK_MOI: float = field(1.0)
    MAX_VEL_1: float = field(4 * jnp.pi, jaxed=False)
    MAX_VEL_2: float = field(9 * jnp.pi, jaxed=False)
    AVAIL_TORQUE: jnp.ndarray = field(jaxed=False)
    torque_noise_max: float = field(0.0, jaxed=False)
    high: jnp.ndarray = field(jaxed=False)
    low: jnp.ndarray = field(jaxed=False)

    def setup(self):
        """Compute observation bounds and fill defaults for key / torques."""
        self.high = jnp.array(
            [1.0, 1.0, 1.0, 1.0, self.MAX_VEL_1, self.MAX_VEL_2])
        self.low = -self.high
        if self.key is None:
            self.key = jax.random.PRNGKey(0)
        if self.AVAIL_TORQUE is None:
            self.AVAIL_TORQUE = jnp.array([-1.0, 0.0, +1])

    def init(self):
        # TODO(dsuo): to implement
        pass

    def __call__(self, state, action):
        """Advance one step of length ``dt``.

        Args:
            state: 4-vector ``[theta1, theta2, dtheta1, dtheta2]``.
            action: torque. It is appended to the state so the integrator
                carries it; presumably a scalar or 1-element array — TODO confirm.

        Returns:
            ``(new_state, obs)``. Angles are passed through ``wrap`` with
            [-pi, pi] and velocities through ``bound`` with +/-MAX_VEL_i (both
            helpers defined elsewhere, presumably wrap-around and clip).
        """
        augmented_state = jnp.append(state, action)
        new_state = rk4(self._dsdt, augmented_state, [0, self.dt])
        # only care about final timestep of integration returned by integrator
        new_state = new_state[-1]
        new_state = new_state[:4]  # omit action
        # ODEINT IS TOO SLOW! (historical alternative, kept for reference)
        # ns_continuous = integrate.odeint(self._dsdt, self.s_continuous, [0, self.dt])
        # self.s_continuous = ns_continuous[-1] # We only care about the state
        # at the ''final timestep'', self.dt
        new_state = new_state.at[0].set(wrap(new_state[0], -jnp.pi, jnp.pi))
        new_state = new_state.at[1].set(wrap(new_state[1], -jnp.pi, jnp.pi))
        new_state = new_state.at[2].set(
            bound(new_state[2], -self.MAX_VEL_1, self.MAX_VEL_1))
        new_state = new_state.at[3].set(
            bound(new_state[3], -self.MAX_VEL_2, self.MAX_VEL_2))
        return (
            new_state,
            jnp.array([
                jnp.cos(new_state[0]),
                jnp.sin(new_state[0]),
                jnp.cos(new_state[1]),
                jnp.sin(new_state[1]),
                new_state[2],
                new_state[3],
            ]),
        )

    def _dsdt(self, augmented_state, t):
        """Time derivative of ``[theta1, theta2, dtheta1, dtheta2, action]``.

        Args:
            augmented_state: state with the action appended as the last entry.
            t: time (unused; kept for the integrator's call signature).

        Returns:
            A 5-tuple ``(dtheta1, dtheta2, ddtheta1, ddtheta2, 0.0)``. The last
            entry is the action's derivative, so it stays constant over a step.
        """
        m1 = self.LINK_MASS_1
        m2 = self.LINK_MASS_2
        l1 = self.LINK_LENGTH_1
        lc1 = self.LINK_COM_POS_1
        lc2 = self.LINK_COM_POS_2
        I1 = self.LINK_MOI
        I2 = self.LINK_MOI
        g = 9.8
        a = augmented_state[-1]
        s = augmented_state[:-1]
        theta1 = s[0]
        theta2 = s[1]
        dtheta1 = s[2]
        dtheta2 = s[3]
        # d1, d2: inertia-like coefficients depending on theta2.
        d1 = m1 * lc1**2 + m2 * (l1**2 + lc2**2 + 2 * l1 * lc2 *
                                 jnp.cos(theta2)) + I1 + I2
        d2 = m2 * (lc2**2 + l1 * lc2 * jnp.cos(theta2)) + I2
        # phi1, phi2: gravity and velocity-product terms.
        phi2 = m2 * lc2 * g * jnp.cos(theta1 + theta2 - jnp.pi / 2.0)
        phi1 = (-m2 * l1 * lc2 * dtheta2**2 * jnp.sin(theta2) -
                2 * m2 * l1 * lc2 * dtheta2 * dtheta1 * jnp.sin(theta2) +
                (m1 * lc1 + m2 * l1) * g * jnp.cos(theta1 - jnp.pi / 2) + phi2)
        ddtheta2 = (a + d2 / d1 * phi1 - phi2) / (m2 * lc2**2 + I2 - d2**2 / d1)
        # Two variants exist upstream; the active line above is the "nips"
        # (paper) form. The "book"/java form would instead be:
        # ddtheta2 = (
        #     a + d2 / d1 * phi1 - m2 * l1 * lc2 * dtheta1 ** 2 * jnp.sin(theta2) - phi2
        # ) / (m2 * lc2 ** 2 + I2 - d2 ** 2 / d1)
        ddtheta1 = -(d2 * ddtheta2 + phi1) / d1
        # 4 state derivatives plus a zero derivative for the appended action.
        return (dtheta1, dtheta2, ddtheta1, ddtheta2, 0.0)
class PIDState(Obj):
    """PIDState: per-step memory of the PID agent.

    Field order matters: it defines the constructor's argument order.
    """
    # Last error seen; PID.__call__ stores the current observation here.
    P: float = field(0.0, jaxed=True)
    # Low-pass filtered error (an exponential moving average, per PID.__call__).
    I: float = field(0.0, jaxed=True)
    # Low-pass filtered step-to-step change in the error.
    D: float = field(0.0, jaxed=True)
class PendulumState(Obj):
    """PendulumState: state container for the Pendulum environment.

    Field order matters: it defines the constructor's argument order.
    """
    # Flat state vector; Pendulum writes it as [sin(th), cos(th), thdot].
    arr: jnp.ndarray = field(jaxed=True)
    # Integer step counter; Pendulum increments it by 1 each step.
    h: int = field(0, jaxed=True)