Example #1
0
class ReacherState(Obj):
    """State object carried between `Reacher` steps.

    Attributes:
      arr: Array payload of the state; this is what `flatten` exposes.
      h: Step counter (`Reacher.__call__` stores `state.h + 1` here).
    """
    arr: jnp.ndarray = field(jaxed=True)
    h: float = field(0.0, jaxed=True)

    def flatten(self):
        """Expose the state as a bare array.

        Returns:
          The `arr` attribute, unchanged.
        """
        flat = self.arr
        return flat

    # TODO(dsuo): this should be a classmethod.
    def unflatten(self, arr):
        """Wrap an array in a new state that keeps this state's counter.

        Args:
          arr: Array to store as the new state's `arr`.

        Returns:
          A `ReacherState` holding `arr` together with the current `h`.
        """
        rebuilt = ReacherState(arr=arr, h=self.h)
        return rebuilt
Example #2
0
class MountainCar(Env):
    """Continuous mountain-car dynamics.

    The state is the array `[position, velocity]`; the observation returned
    by `init` and `__call__` is the state itself.

    Attributes:
      key: PRNG key used to draw the initial position. `setup` substitutes
        `jax.random.PRNGKey(0)` when it is None.
      goal_velocity: Stored only; not read by the methods defined here.
      min_action: Lower clamp applied to the action.
      max_action: Upper clamp applied to the action.
      min_position: Lower clamp applied to the position.
      max_position: Upper clamp applied to the position.
      max_speed: Velocity is clamped to `[-max_speed, max_speed]`.
      goal_position: Stored only; not read by the methods defined here.
      power: Scale factor turning the clamped action into an acceleration.
      low_state: `[min_position, -max_speed]`, filled in by `setup`.
      high_state: `[max_position, max_speed]`, filled in by `setup`.
    """
    key: jnp.ndarray = field(jaxed=False)
    goal_velocity: float = field(0.0, jaxed=False)
    min_action: float = field(-1.0, jaxed=False)
    max_action: float = field(1.0, jaxed=False)
    min_position: float = field(-1.2, jaxed=False)
    max_position: float = field(0.6, jaxed=False)
    max_speed: float = field(0.07, jaxed=False)
    goal_position: float = field(0.5, jaxed=False)
    power = 0.0015

    low_state: jnp.ndarray = field(jaxed=False)
    high_state: jnp.ndarray = field(jaxed=False)

    def setup(self):
        """Derive the state bounds and make sure a PRNG key is available."""
        self.low_state = jnp.array([self.min_position, -self.max_speed])
        self.high_state = jnp.array([self.max_position, self.max_speed])
        if self.key is None:
            self.key = jax.random.PRNGKey(0)

    def init(self):
        """Draw the initial state.

        Returns:
          A `(state, observation)` pair; both are the same `[position, 0]`
          array, with the position sampled uniformly from `[-0.6, 0.4)`.
        """
        # The keyword is `minval`; the former `min_val` spelling raised a
        # TypeError on every call.
        # NOTE(review): the upper bound is +0.4 here; confirm that -0.4 was
        # not intended.
        start_position = jax.random.uniform(self.key, minval=-0.6, maxval=0.4)
        state = jnp.array([start_position, 0])
        return state, state

    def __call__(self, state, action):
        """Advance the car by one step.

        Args:
          state: Array unpacking into `(position, velocity)`.
          action: Force command; clamped to `[min_action, max_action]`. May
            be a scalar or a one-element array.

        Returns:
          A `(new_state, observation)` pair; both are the same array of
          shape `(2,)`.
        """
        position, velocity = state

        force = jnp.clip(action, self.min_action, self.max_action)

        velocity = velocity + (force * self.power -
                               0.0025 * jnp.cos(3 * position))
        velocity = jnp.clip(velocity, -self.max_speed, self.max_speed)

        position = position + velocity
        position = jnp.clip(position, self.min_position, self.max_position)

        # Stop the car when it is pushed against the left wall. `jnp.where`
        # works for scalar and one-element inputs alike and avoids the
        # legacy five-argument `jax.lax.cond` call used previously.
        hit_left_wall = (position == self.min_position) & (velocity < 0)
        velocity = jnp.where(hit_left_wall, 0.0, velocity)

        new_state = jnp.reshape(jnp.array([position, velocity]), (2, ))

        return new_state, new_state
Example #3
0
class Random(Agent):
    """Agent whose action is a user-supplied function of a PRNG key.

    Attributes:
      seed: Integer seed that `init` turns into the initial PRNG key.
      func: Callable mapping a PRNG key to an action; the default ignores
        the key and returns 0.0.
    """
    seed: int = field(0, jaxed=False)
    func: Callable = field(lambda key: 0.0, jaxed=False)

    def init(self):
        """Return the initial agent state: a PRNG key built from `seed`."""
        initial_key = jax.random.PRNGKey(self.seed)
        return initial_key

    def __call__(self, state, obs):
        """Produce one action and the key to carry into the next call.

        Args:
          state: PRNG key held as the agent state.
          obs: Observation; not used.

        Returns:
          A `(next_state, action)` pair, where `next_state` is the second
          half of the split key and `action` is `func` applied to the first.
        """
        action_key, carried_key = jax.random.split(state)
        action = self.func(action_key)
        return carried_key, action
Example #4
0
class BangBang(Agent):
    """Two-level controller that switches on a threshold.

    Attributes:
      target: Threshold the observation is compared against.
      min_action: Action emitted when the observation exceeds `target`.
      max_action: Action emitted otherwise.
    """
    target: float = field(0.0, jaxed=False)
    min_action: float = field(0.0, jaxed=False)
    max_action: float = field(100.0, jaxed=False)

    def init(self):
        """Return the initial agent state, which is None."""
        return None

    def __call__(self, state, obs):
        """Pick the action for `obs`; assumes the observation is env state.

        Args:
          state: Agent state; passed through unchanged.
          obs: Value compared against `target`.

        Returns:
          A `(state, action)` pair.
        """
        if obs > self.target:
            action = self.min_action
        else:
            action = self.max_action
        return state, action
Example #5
0
class PlanarQuadrotorState(Obj):
    """State object carried between `PlanarQuadrotor` steps.

    Attributes:
      arr: Array payload of the state; this is what `flatten` exposes.
      h: Step counter (`PlanarQuadrotor.__call__` stores `state.h + 1`).
      last_action: Most recent action recorded by the environment.
    """
    arr: jnp.ndarray = field(jaxed=True)
    h: float = field(0.0, jaxed=True)
    last_action: jnp.ndarray = field(jaxed=True)

    def flatten(self):
        """Return the `arr` attribute, unchanged."""
        flat = self.arr
        return flat

    def unflatten(self, arr):
        """Build a new state around `arr`, keeping `h` and `last_action`.

        Args:
          arr: Array to store as the new state's `arr`.

        Returns:
          A `PlanarQuadrotorState` with the given array and this state's
          counter and last action.
        """
        carried = dict(h=self.h, last_action=self.last_action)
        return PlanarQuadrotorState(arr=arr, **carried)
Example #6
0
class Zero(Agent):
    """Agent that always outputs an all-zero action.

    Attributes:
      action_dim: Shape argument handed to `jnp.zeros` for the action.
    """
    action_dim: int = field(1, jaxed=False)

    def init(self):
        """Return the initial agent state, which is None."""
        return None

    def __call__(self, state, obs):
        """Return the unchanged state and a zero action.

        Args:
          state: Agent state; passed through unchanged.
          obs: Observation; not used.

        Returns:
          A `(state, action)` pair with `action = jnp.zeros(action_dim)`.
        """
        action = jnp.zeros(self.action_dim)
        return state, action
Example #7
0
class LDS(Env):
    """Linear dynamical system: `x' = A x + B u`, observed as `C x'`.

    Attributes:
      A: State-transition matrix.
      B: Input matrix.
      C: Observation matrix.
      key: PRNG key used to sample the initial state.
      state_size: Number of rows of the sampled initial state.
      action_size: Stored only; not read by the methods defined here.
    """
    # `jnp.ndarray` is the array type; `jnp.array` is a constructor function.
    A: jnp.ndarray = field(jaxed=False)
    B: jnp.ndarray = field(jaxed=False)
    C: jnp.ndarray = field(jaxed=False)
    # The default is a PRNG key array, not an integer seed.
    key: jnp.ndarray = field(jax.random.PRNGKey(0), jaxed=False)
    state_size: int = field(1, jaxed=False)
    action_size: int = field(1, jaxed=False)

    def init(self):
        """Sample the initial state.

        Returns:
          A `(state, observation)` pair; both are the same standard-normal
          array of shape `(state_size, 1)`.
        """
        state = jax.random.normal(self.key, shape=(self.state_size, 1))
        return state, state

    def __call__(self, state, action):
        """Advance the system by one step.

        Args:
          state: Current state; must be compatible with `A @ state`.
          action: Control input; must be compatible with `B @ action`.

        Returns:
          A `(new_state, observation)` pair with
          `new_state = A @ state + B @ action` and
          `observation = C @ new_state`.
        """
        new_state = self.A @ state + self.B @ action

        return new_state, self.C @ new_state
Example #8
0
class Cartpole(Env):
    """Cart-pole with linear discrete-time dynamics `x' = A x + B (u + offset)`.

    The state array has four components; rows 0-1 of the transition matrix
    integrate components 2-3 with step `dt`, so the last two components act
    as the time derivatives of the first two.

    Attributes:
      m: Mass term used in the transition matrix.
      M: Mass term used in the transition and input matrices.
      l: Length term used in the transition and input matrices.
      g: Gravitational acceleration.
      dt: Integration step.
      H: Stored only; not read by the methods defined here.
      goal_state: Target state; `setup` defaults it to all zeros.
      dynamics: Stored only; not read by the methods defined here.
    """
    m: float = field(0.1, jaxed=False)
    M: float = field(1.0, jaxed=False)
    l: float = field(1.0, jaxed=False)
    g: float = field(9.81, jaxed=False)
    dt: float = field(0.02, jaxed=False)
    H: int = field(10, jaxed=False)
    goal_state: jnp.ndarray = field(jaxed=False)
    dynamics: bool = field(False, jaxed=False)

    def setup(self):
        """Default `goal_state` to the origin when none was supplied."""
        if self.goal_state is None:
            self.goal_state = jnp.array([0.0, 0.0, 0.0, 0.0])

    def init(self):
        """Build the initial state.

        Returns:
          A `(state, observation)` pair; both are the same `CartpoleState`
          whose array is all zeros.
        """
        state = CartpoleState(arr=jnp.array([0.0, 0.0, 0.0, 0.0]))
        return state, state

    def __call__(self, state, action):
        """Advance the system by one step.

        Args:
          state: `CartpoleState` providing `arr`, `h` and `offset`.
          action: Control input; presumably a one-element array so that
            `B @ action` is defined (TODO confirm against callers).

        Returns:
          A `(new_state, observation)` pair: the state with the updated
          array and `h` incremented by one, and the updated array itself.
        """
        A = jnp.array([
            [1.0, 0.0, self.dt, 0.0],
            [0.0, 1.0, 0.0, self.dt],
            [0.0, self.dt * self.m * self.g / self.M, 1.0, 0.0],
            [
                0.0, self.dt * (self.m + self.M) * self.g / (self.M * self.l),
                0.0, 1.0
            ],
        ])
        # B and arr must be arrays. Stray trailing commas used to wrap both
        # in one-element tuples, which made `B @ ...` raise a TypeError.
        B = jnp.array([[0.0], [0.0], [self.dt / self.M],
                       [self.dt / (self.M * self.l)]])
        arr = A @ state.arr + B @ (action + state.offset)
        return state.replace(arr=arr, h=state.h + 1), arr
Example #9
0
class Pendulum(Env):
    """Torque-driven pendulum whose state array is `[sin, cos, thdot]`.

    Attributes:
      m: Mass.
      l: Length.
      g: Gravitational acceleration.
      max_torque: Scale applied to `tanh` of the raw action.
      dt: Integration step.
      H: Stored only; not read by the methods defined here.
      goal_state: Target state; `setup` defaults it to `[0, -1, 0]`.
    """
    m: float = field(1.0, jaxed=False)
    l: float = field(1.0, jaxed=False)
    g: float = field(9.81, jaxed=False)
    max_torque: float = field(1.0, jaxed=False)
    dt: float = field(0.02, jaxed=False)
    H: int = field(300, jaxed=False)
    goal_state: jnp.ndarray = field(jaxed=False)

    def init(self):
        """Build the initial state.

        Returns:
          A `(state, observation)` pair; both are the same `PendulumState`
          with array `[0, 1, 0]`.
        """
        start = jnp.array([0.0, 1.0, 0.0])
        state = PendulumState(arr=start)
        return state, state

    def setup(self):
        """Default `goal_state` when none was supplied."""
        if self.goal_state is None:
            self.goal_state = jnp.array([0., -1., 0.])

    def __call__(self, state, action):
        """Advance the pendulum by one step.

        Args:
          state: `PendulumState` whose `arr` unpacks into three values.
          action: Indexable control input; only `action[0]` is used.

        Returns:
          A `(new_state, observation)` pair: a fresh `PendulumState` with
          `h` incremented by one, and its array `[sin, cos, thdot]`.
        """
        sin, cos, _ = state.arr
        theta = jnp.arctan2(sin, cos)
        # Squash the raw action into [-max_torque, max_torque].
        torque = self.max_torque * jnp.tanh(action[0])

        # NOTE(review): the previous angular velocity (third state component)
        # is discarded and this update is seeded with the angle instead;
        # confirm this is intended.
        gravity_term = (-3.0 * self.g / (2.0 * self.l) *
                        jnp.sin(theta + jnp.pi))
        torque_term = 3.0 / (self.m * self.l**2) * torque
        newthdot = theta + (gravity_term + torque_term)

        newth = theta + newthdot * self.dt
        arr = jnp.array([jnp.sin(newth), jnp.cos(newth), newthdot])
        return PendulumState(arr=arr, h=state.h + 1), arr
Example #10
0
class PID(Agent):
    """PID-style controller with exponentially smoothed I and D terms.

    Attributes:
      K_P: Gain on the proportional term.
      K_I: Gain on the smoothed ("I") term.
      K_D: Gain on the smoothed difference ("D") term.
      RC: Time constant used to derive `decay`.
      dt: Time step used to derive `decay`.
      decay: Smoothing factor `dt / (dt + RC)`, computed by `setup`.
    """
    K_P: float = field(0.0, jaxed=True)
    K_I: float = field(0.0, jaxed=True)
    K_D: float = field(0.0, jaxed=True)
    RC: float = field(0.5, jaxed=False)
    dt: float = field(0.03, jaxed=False)
    decay: float = field(jaxed=False)

    def setup(self):
        """Compute the smoothing factor from `dt` and `RC`."""
        self.decay = self.dt / (self.dt + self.RC)

    def init(self):
        """Return a fresh `PIDState` with its default terms."""
        return PIDState()

    def __call__(self, state, obs):
        """Compute the control action; assumes the observation is the error.

        Args:
          state: `PIDState` holding the previous `P`, `I` and `D` terms.
          obs: Current error.

        Returns:
          A `(new_state, action)` pair, where `new_state` stores the updated
          terms and `action` is their gain-weighted sum.
        """
        proportional = obs
        # Move the stored I term a fraction `decay` of the way towards obs.
        integral = state.I + self.decay * (obs - state.I)
        # Same smoothing applied to the change in error since the last call.
        derivative = state.D + self.decay * (obs - state.P - state.D)

        action = (self.K_P * proportional + self.K_I * integral +
                  self.K_D * derivative)

        next_state = state.replace(P=proportional, I=integral, D=derivative)
        return next_state, action
Example #11
0
class BraxEnv(Env):
  """Environment wrapper around a brax system, optionally with a brax env.

  Attributes:
    sys: The `brax.System` that is stepped. When only `env` is supplied,
      `setup` takes it from `env.sys`.
    env: Optional brax environment; when present it supplies resets and
      observations.
  """
  sys: brax.System = field(jaxed=False)
  env: brax_envs.Env = field(jaxed=False)

  @classmethod
  def from_env(cls, env):
    """Create a wrapper from an existing brax environment."""
    return cls.create(env=env, sys=env.sys)

  @classmethod
  def from_name(cls, name):
    """Create a wrapper from the name of a registered brax environment."""
    return cls.from_env(brax_envs.create(env_name=name))

  @classmethod
  def from_config(cls, config):
    """Create a wrapper around a bare `brax.System` built from `config`."""
    return cls.create(sys=brax.System(config))

  def setup(self):
    """Resolve `sys`, falling back to `env.sys`.

    Raises:
      ValueError: If neither `sys` nor `env` was specified.
    """
    if self.sys is None:
      if self.env is None:
        raise ValueError("BraxEnv requires `sys` or `env` field specified.")
      # The error message promises that `env` alone is enough; honour that
      # the same way `from_env` does.
      self.sys = self.env.sys

  def init(self, rng=None):
    """Produce the initial state and observation.

    Args:
      rng: Forwarded to `env.reset` when an env is present; otherwise unused.

    Returns:
      A `(state, observation)` pair. Without an env, both are the system's
      default qp; with an env, they are the reset state's `qp` and `obs`.
    """
    if self.env is None:
      state = self.sys.default_qp()
      return state, state
    state = self.env.reset(rng=rng)
    return state.qp, state.obs

  def reset(self, rng=None):
    """Alias for `init`."""
    return self.init(rng=rng)

  def __call__(self, state, action):
    """Step the system once.

    Args:
      state: Current system state, as accepted by `sys.step`.
      action: Action passed to `sys.step`.

    Returns:
      A `(state, observation)` pair; the observation is the state itself
      when no env is present, otherwise `env._get_obs(state, info)`.
    """
    state, info = self.sys.step(state, action)
    obs = state if self.env is None else self.env._get_obs(
        state, info)
    return state, obs

  def render(self, states):
    """Render states as HTML.

    Args:
      states: States handed to `html.render` together with `sys`.

    Returns:
      An `HTML` object wrapping the rendered markup.
    """
    return HTML(html.render(self.sys, states))
Example #12
0
class CartpoleState(Obj):
    """State object carried between `Cartpole` steps.

    Attributes:
      arr: State array; `Cartpole.init` creates it with four zeros.
      h: Step counter; `Cartpole.__call__` stores `state.h + 1`.
      offset: Value added to the action before `Cartpole.__call__` applies
        the input matrix.
    """
    arr: jnp.ndarray = field(jaxed=True)
    h: int = field(0, jaxed=True)
    offset: float = field(0.0, jaxed=False)
Example #13
0
class PlanarQuadrotor(Env):
    """Planar quadrotor with two rotor thrusts and an optional wind field.

    The state array is `[x, y, th, xdot, ydot, thdot]` and is advanced with
    a forward-Euler step of size `dt`.

    Attributes:
      m: Mass.
      l: Arm length used in the angular acceleration.
      g: Gravitational acceleration.
      dt: Integration step.
      H: Stored only; not read by the methods defined here.
      wind: Wind parameter forwarded to `wind_func`.
      wind_func: Callable `(x, y, wind)` whose result is indexed at 0 and 1
        for the horizontal and vertical wind terms.
      goal_state: Target state; `setup` defaults it to six zeros.
      goal_action: Set by `setup` to `[m * g / 2, m * g / 2]`.
      state_dim: Stored only; not read by the methods defined here.
      action_dim: Stored only; not read by the methods defined here.
    """
    m: float = field(0.1, jaxed=False)
    l: float = field(0.2, jaxed=False)
    g: float = field(9.81, jaxed=False)
    dt: float = field(0.05, jaxed=False)
    H: int = field(100, jaxed=False)
    wind: float = field(0.0, jaxed=False)
    wind_func: Callable[[float, float, float],
                        List[float]] = field(dissipative, jaxed=False)
    goal_state: jnp.ndarray = field(jaxed=False)
    goal_action: jnp.ndarray = field(jaxed=False)
    state_dim: float = field(6, jaxed=False)
    action_dim: float = field(2, jaxed=False)

    def setup(self):
        """Set `goal_action` and default `goal_state` to the origin."""
        per_rotor = self.m * self.g / 2.0
        self.goal_action = jnp.array([per_rotor, per_rotor])
        if self.goal_state is None:
            self.goal_state = jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0])

    def _wind_field(self, x, y):
        """Evaluate the wind function at a position.

        Args:
          x: Horizontal position.
          y: Vertical position.

        Returns:
          Whatever `wind_func(x, y, wind)` returns.
        """
        return self.wind_func(x, y, self.wind)

    def init(self):
        """Build the initial state.

        Returns:
          A `(state, observation)` pair: a `PlanarQuadrotorState` with array
          `[1, 1, 0, 0, 0, 0]` and a zero last action, and that array.
        """
        start = jnp.array([1., 1., 0., 0., 0., 0.])
        no_action = jnp.array([0., 0.])
        new_state = PlanarQuadrotorState(arr=start, last_action=no_action)
        return new_state, new_state.arr

    def __call__(self, state, action):
        """Advance the quadrotor by one step.

        Args:
          state: `PlanarQuadrotorState` whose `arr` unpacks into six values.
          action: Pair `(u1, u2)` of rotor thrusts.

        Returns:
          A `(new_state, observation)` pair: the state with the updated
          array, the action recorded as `last_action` and `h` incremented by
          one, and the updated array itself.
        """
        x, y, th, xdot, ydot, thdot = state.arr
        u1, u2 = action
        mass, arm = self.m, self.l
        wind = self._wind_field(x, y)

        total_thrust = u1 + u2
        xddot = -total_thrust * jnp.sin(th) / mass + wind[0] / mass
        yddot = total_thrust * jnp.cos(th) / mass - self.g + wind[1] / mass
        thddot = arm * (u2 - u1) / (mass * arm**2)

        state_dot = jnp.array([xdot, ydot, thdot, xddot, yddot, thddot])
        arr = state.arr + state_dot * self.dt
        next_state = state.replace(arr=arr, last_action=action, h=state.h + 1)
        return next_state, arr
Example #14
0
class Reacher(Env):
    """Two-link arm driven by joint torques.

    The state array is `[th1, th2, dth1, dth2, Dx, Dy]`, where `(Dx, Dy)` is
    the arm tip position minus `goal_coord`.

    Attributes:
      m1: Mass of the first link.
      m2: Mass of the second link.
      l1: Length of the first link.
      l2: Length of the second link.
      g: Gravitational acceleration.
      max_torque: Stored only; not read by the methods defined here.
      dt: Integration step.
      H: Stored only; not read by the methods defined here.
      goal_coord: Target tip coordinate; `setup` defaults it to `[0, 1.8]`.
      state_dim: Stored only; not read by the methods defined here.
      action_dim: Stored only; not read by the methods defined here.
    """
    m1: float = field(1.0, jaxed=False)
    m2: float = field(1.0, jaxed=False)
    l1: float = field(1.0, jaxed=False)
    l2: float = field(1.0, jaxed=False)
    g: float = field(0.0, jaxed=False)
    max_torque: float = field(1.0, jaxed=False)
    dt: float = field(0.01, jaxed=False)
    H: int = field(200, jaxed=False)
    goal_coord: jnp.ndarray = field(jaxed=False)
    state_dim: float = field(6, jaxed=False)
    action_dim: float = field(2, jaxed=False)

    def _goal_offset(self, th1, th2):
        """Return `(Dx, Dy)`: the tip position minus `goal_coord`."""
        dx = (self.l1 * jnp.cos(th1) + self.l2 * jnp.cos(th1 + th2) -
              self.goal_coord[0])
        dy = (self.l1 * jnp.sin(th1) + self.l2 * jnp.sin(th1 + th2) -
              self.goal_coord[1])
        return dx, dy

    def init(self):
        """Build the initial state.

        Returns:
          A `(state, observation)` pair; both are the same `ReacherState`
          with joint angles `(pi / 4, pi / 2)` and zero joint velocities.
        """
        initial_th = (jnp.pi / 4, jnp.pi / 2)
        dx, dy = self._goal_offset(initial_th[0], initial_th[1])
        state = ReacherState(arr=jnp.array([*initial_th, 0.0, 0.0, dx, dy]))
        return state, state

    def setup(self):
        """Default `goal_coord` when none was supplied."""
        if self.goal_coord is None:
            self.goal_coord = jnp.array([0., 1.8])

    def __call__(self, state, action):
        """Advance the arm by one step.

        Args:
          state: `ReacherState` whose `arr` unpacks into six values.
          action: Pair `(t1, t2)` of joint torques.

        Returns:
          A `(new_state, observation)` pair: the state with the updated
          array and `h` incremented by one, and the updated array itself.
        """
        m1, m2, l1, l2, g = self.m1, self.m2, self.l1, self.l2, self.g
        th1, th2, dth1, dth2, _, _ = state.arr
        t1, t2 = action

        cos2, sin2 = jnp.cos(th2), jnp.sin(th2)
        coupling = m2 * l1 * l2
        gravity2 = m2 * l2 * g * jnp.sin(th1 + th2)

        # Entries of the symmetric 2x2 matrix multiplying the accelerations.
        a11 = (m1 + m2) * l1**2 + m2 * l2**2 + 2 * m2 * l1 * l2 * cos2
        a12 = m2 * l2**2 + coupling * cos2
        a22 = m2 * l2**2
        # Right-hand side: torques plus velocity-dependent and gravity terms.
        b1 = (t1 + coupling * (2 * dth1 + dth2) * dth2 * sin2 - gravity2 -
              (m1 + m2) * l1 * g * jnp.sin(th1))
        b2 = t2 - coupling * dth1**2 * sin2 - gravity2

        A = jnp.array([[a11, a12], [a12, a22]])
        b = jnp.array([b1, b2])
        ddth1, ddth2 = jnp.linalg.inv(A) @ b

        # Angles advance with the old velocities, velocities with the new
        # accelerations.
        th1, th2 = th1 + dth1 * self.dt, th2 + dth2 * self.dt
        dth1, dth2 = dth1 + ddth1 * self.dt, dth2 + ddth2 * self.dt
        Dx, Dy = self._goal_offset(th1, th2)

        arr = jnp.array([th1, th2, dth1, dth2, Dx, Dy])
        new_state = state.replace(arr=arr, h=state.h + 1)

        return new_state, arr
Example #15
0
class Acrobot(Env):
    """Acrobot: a two-link pendulum actuated by a torque.

    The internal state is `[theta1, theta2, dtheta1, dtheta2]`; the
    observation is `[cos(theta1), sin(theta1), cos(theta2), sin(theta2),
    dtheta1, dtheta2]`.

    Attributes:
      key: PRNG key; `setup` substitutes `jax.random.PRNGKey(0)` when None.
      dt: Length of the integration interval per step.
      LINK_LENGTH_1: Length of the first link.
      LINK_LENGTH_2: Stored only; not read by the methods defined here.
      LINK_MASS_1: Mass of the first link.
      LINK_MASS_2: Mass of the second link.
      LINK_COM_POS_1: Centre-of-mass position of the first link.
      LINK_COM_POS_2: Centre-of-mass position of the second link.
      LINK_MOI: Moment of inertia used for both links.
      MAX_VEL_1: Bound on the first angular velocity.
      MAX_VEL_2: Bound on the second angular velocity.
      AVAIL_TORQUE: Torque values; `setup` defaults it to `[-1, 0, +1]`.
        Not read by `__call__`.
      torque_noise_max: Stored only; not read by the methods defined here.
      high: Upper observation bounds, filled in by `setup`.
      low: Negation of `high`, filled in by `setup`.
    """
    key: jnp.ndarray = field(jaxed=False)
    dt: float = field(0.2, jaxed=False)

    LINK_LENGTH_1: float = field(1.0, jaxed=False)
    LINK_LENGTH_2: float = field(1.0, jaxed=False)
    LINK_MASS_1: float = field(1.0, jaxed=False)
    LINK_MASS_2: float = field(1.0, jaxed=False)
    LINK_COM_POS_1: float = field(0.0, jaxed=False)
    LINK_COM_POS_2: float = field(0.0, jaxed=False)
    # NOTE(review): unlike every other field here, this one does not pass
    # `jaxed=False` -- confirm whether that is intentional.
    LINK_MOI: float = field(1.0)

    MAX_VEL_1: float = field(4 * jnp.pi, jaxed=False)
    MAX_VEL_2: float = field(9 * jnp.pi, jaxed=False)

    AVAIL_TORQUE: jnp.ndarray = field(jaxed=False)

    torque_noise_max: float = field(0.0, jaxed=False)

    high: jnp.ndarray = field(jaxed=False)
    low: jnp.ndarray = field(jaxed=False)

    def setup(self):
        """Fill in the observation bounds and default `key`/`AVAIL_TORQUE`."""
        self.high = jnp.array(
            [1.0, 1.0, 1.0, 1.0, self.MAX_VEL_1, self.MAX_VEL_2])
        self.low = -self.high
        if self.key is None:
            self.key = jax.random.PRNGKey(0)
        if self.AVAIL_TORQUE is None:
            self.AVAIL_TORQUE = jnp.array([-1.0, 0.0, +1])

    def init(self):
        """Not implemented yet; currently returns None."""
        # TODO(dsuo): to implement
        pass

    def __call__(self, state, action):
        """Advance the acrobot by one step of length `dt`.

        Args:
          state: Array `[theta1, theta2, dtheta1, dtheta2]`.
          action: Torque; appended to the state so the integrator carries it.

        Returns:
          A `(new_state, observation)` pair: the updated four-element state,
          and the six-element array of angle cosines/sines and velocities.
        """
        # The integrator works on a single vector, so the action is appended
        # as a fifth component (its derivative is 0 in `_dsdt`).
        augmented_state = jnp.append(state, action)

        new_state = rk4(self._dsdt, augmented_state, [0, self.dt])
        # only care about final timestep of integration returned by integrator
        new_state = new_state[-1]
        new_state = new_state[:4]  # omit action
        # ODEINT IS TOO SLOW!
        # ns_continuous = integrate.odeint(self._dsdt, self.s_continuous, [0, self.dt])
        # self.s_continuous = ns_continuous[-1] # We only care about the state
        # at the ''final timestep'', self.dt

        # `wrap` and `bound` are defined elsewhere; presumably they wrap the
        # angles into [-pi, pi] and clamp the velocities -- verify.
        new_state = new_state.at[0].set(wrap(new_state[0], -jnp.pi, jnp.pi))
        new_state = new_state.at[1].set(wrap(new_state[1], -jnp.pi, jnp.pi))
        new_state = new_state.at[2].set(
            bound(new_state[2], -self.MAX_VEL_1, self.MAX_VEL_1))
        new_state = new_state.at[3].set(
            bound(new_state[3], -self.MAX_VEL_2, self.MAX_VEL_2))

        return (
            new_state,
            jnp.array([
                jnp.cos(new_state[0]),
                jnp.sin(new_state[0]),
                jnp.cos(new_state[1]),
                jnp.sin(new_state[1]),
                new_state[2],
                new_state[3],
            ]),
        )

    def _dsdt(self, augmented_state, t):
        """Time derivative of the augmented state.

        Args:
          augmented_state: The four state components followed by the torque.
          t: Time; not used.

        Returns:
          Tuple `(dtheta1, dtheta2, ddtheta1, ddtheta2, 0.0)`; the trailing
          zero keeps the appended torque constant during integration.
        """
        m1 = self.LINK_MASS_1
        m2 = self.LINK_MASS_2
        l1 = self.LINK_LENGTH_1
        lc1 = self.LINK_COM_POS_1
        lc2 = self.LINK_COM_POS_2
        I1 = self.LINK_MOI
        I2 = self.LINK_MOI
        g = 9.8
        # Last component is the torque, the rest is the physical state.
        a = augmented_state[-1]
        s = augmented_state[:-1]

        theta1 = s[0]
        theta2 = s[1]
        dtheta1 = s[2]
        dtheta2 = s[3]

        d1 = m1 * lc1**2 + m2 * (l1**2 + lc2**2 +
                                 2 * l1 * lc2 * jnp.cos(theta2)) + I1 + I2
        d2 = m2 * (lc2**2 + l1 * lc2 * jnp.cos(theta2)) + I2
        phi2 = m2 * lc2 * g * jnp.cos(theta1 + theta2 - jnp.pi / 2.0)
        phi1 = (-m2 * l1 * lc2 * dtheta2**2 * jnp.sin(theta2) -
                2 * m2 * l1 * lc2 * dtheta2 * dtheta1 * jnp.sin(theta2) +
                (m1 * lc1 + m2 * l1) * g * jnp.cos(theta1 - jnp.pi / 2) + phi2)
        ddtheta2 = (a + d2 / d1 * phi1 - phi2) / (m2 * lc2**2 + I2 -
                                                  d2**2 / d1)
        # The active line above is the "nips" variant kept from the
        # reference code below; the "book" variant is retained for reference.
        # if self.book_or_nips == "nips":
        # the following line is consistent with the description in the
        # paper
        # ddtheta2 = (a + d2 / d1 * phi1 - phi2) / (m2 * lc2 ** 2 + I2 - d2 ** 2 / d1)
        # else:
        # the following line is consistent with the java implementation and the
        # book
        # ddtheta2 = (
        # a + d2 / d1 * phi1 - m2 * l1 * lc2 * dtheta1 ** 2 * jnp.sin(theta2) - phi2
        # ) / (m2 * lc2 ** 2 + I2 - d2 ** 2 / d1)
        ddtheta1 = -(d2 * ddtheta2 + phi1) / d1
        return (dtheta1, dtheta2, ddtheta1, ddtheta2, 0.0)  # 4-state version
Example #16
0
class PIDState(Obj):
    """State object carried between `PID` calls.

    Attributes:
      P: Observation stored by the previous `PID.__call__`.
      I: Smoothed term updated by `PID.__call__`.
      D: Smoothed difference term updated by `PID.__call__`.
    """
    P: float = field(0.0, jaxed=True)
    I: float = field(0.0, jaxed=True)
    D: float = field(0.0, jaxed=True)
Example #17
0
class PendulumState(Obj):
    """State object carried between `Pendulum` steps.

    Attributes:
      arr: State array; `Pendulum.__call__` writes `[sin, cos, thdot]`.
      h: Step counter; `Pendulum.__call__` stores `state.h + 1`.
    """
    arr: jnp.ndarray = field(jaxed=True)
    h: int = field(0, jaxed=True)