# Example no. 1
0
class LDS(Env):
    """
    Linear dynamical system with state update x' = A x + B u and
    observation y = C x.

    A and B default to matrices with standard-normal entries drawn from the
    seeded ``Random`` generator; C defaults to the identity.
    """

    def __init__(self,
                 state_size=1,
                 action_size=1,
                 A=None,
                 B=None,
                 C=None,
                 seed=0):
        """
        Args:
            state_size (int): dimension of the state vector
            action_size (int): dimension of the action vector
            A (optional): (state_size, state_size) dynamics matrix
            B (optional): (state_size, action_size) control matrix
            C (optional): observation matrix; identity when omitted
            seed (int): seed for the internal key generator
        """
        if A is not None:
            assert (
                A.shape[0] == state_size and A.shape[1] == state_size
            ), "ERROR: Your input dynamics matrix does not have the correct shape."

        if B is not None:
            assert (
                B.shape[0] == state_size and B.shape[1] == action_size
            ), "ERROR: Your input control matrix does not have the correct shape."

        self.random = Random(seed)

        self.state_size, self.action_size = state_size, action_size

        self.A = (A if A is not None else jax.random.normal(
            self.random.generate_key(), shape=(state_size, state_size)))

        self.B = (B if B is not None else jax.random.normal(
            self.random.generate_key(), shape=(state_size, action_size)))

        self.C = (C if C is not None else jax.numpy.identity(self.state_size))

        self.t = 0

        self.reset()

    def step(self, action):
        """Advance the state by one step and refresh the observation."""
        self.state = self.A @ self.state + self.B @ action
        self.obs = self.C @ self.state

    @staticmethod
    @jax.jit
    def _linear_dynamics(A, B, state, action):
        # Pure function of arrays so jit can trace it. Decorating the bound
        # method directly would make JAX try to trace ``self`` and fail.
        return A @ state + B @ action

    def dynamics(self, state, action):
        """Return the next state A @ state + B @ action without mutating self."""
        return self._linear_dynamics(self.A, self.B, state, action)

    def reset(self):
        """Draw a fresh standard-normal (state_size, 1) state and its observation."""
        self.state = jax.random.normal(self.random.generate_key(),
                                       shape=(self.state_size, 1))
        self.obs = self.C @ self.state
# Example no. 2
0
class Deep(Agent):
    """
    Generic deep controller that uses zero-order methods to train on an
    environment.

    The policy is a linear-softmax map from state to action probabilities,
    trained with REINFORCE. ``feed`` stores per-step rewards and log-policy
    gradients, and ``update`` applies them at the end of an episode.
    """

    def __init__(
        self,
        env: Env,
        learning_rate: Real = 0.001,
        gamma: Real = 0.99,
        max_episode_length: int = 500,
        seed: int = 0,
    ) -> None:
        """
        Description: initializes the Deep agent

        Args:
            env (Env): a deluca environment
            learning_rate (Real): step size for the weight update
            gamma (Real): discount factor for future rewards
            max_episode_length (int): capacity of the per-episode buffers
            seed (int): seed for the internal key generator

        Returns:
            None
        """
        self.env = env
        self.max_episode_length = max_episode_length
        self.lr = learning_rate
        self.gamma = gamma

        self.random = Random(seed)

        self.reset()

    def reset(self) -> None:
        """
        Description: reset agent weights and episode buffers

        Args:
            None

        Returns:
            None
        """
        # Init weight: (state_size, n_actions), uniform in [0, 1)
        self.W = jax.random.uniform(
            self.random.generate_key(),
            shape=(self.env.state_size, len(self.env.action_space)),
            minval=0,
            maxval=1,
        )

        # Per-episode buffers, filled by feed() and consumed by update()
        self.current_episode_length = 0
        self.current_episode_reward = 0
        self.episode_rewards = jnp.zeros(self.max_episode_length)
        self.episode_grads = jnp.zeros((self.max_episode_length, self.W.shape[0], self.W.shape[1]))

    def policy(self, state: jnp.ndarray, w: jnp.ndarray) -> jnp.ndarray:
        """
        Description: Policy that maps state to action parameterized by w

        Args:
            state (jnp.ndarray): current state
            w (jnp.ndarray): weight matrix

        Returns:
            jnp.ndarray: softmax action probabilities
        """
        z = jnp.dot(state, w)
        exp = jnp.exp(z)
        return exp / jnp.sum(exp)

    def softmax_grad(self, softmax: jnp.ndarray) -> jnp.ndarray:
        """
        Description: Vectorized softmax Jacobian, diag(s) - s s^T

        Args:
            softmax (jnp.ndarray): softmax probabilities
        """
        s = softmax.reshape(-1, 1)
        return jnp.diagflat(s) - jnp.dot(s, s.T)

    def __call__(self, state: jnp.ndarray):
        """
        Description: provide an action given a state

        Args:
            state (jnp.ndarray):

        Returns:
            jnp.ndarray: action to take
        """
        self.state = state
        self.probs = self.policy(state, self.W)
        self.action = jax.random.choice(
            self.random.generate_key(), a=self.env.action_space, p=self.probs
        )
        return self.action

    def feed(self, reward: Real) -> None:
        """
        Description: compute gradient and save with reward in memory for weight updates

        Args:
            reward (Real): reward received for the last action

        Returns:
            None
        """
        # d log pi(a|s) / dz for the sampled action, then chain through z = s @ W
        dsoftmax = self.softmax_grad(self.probs)[self.action, :]
        dlog = dsoftmax / self.probs[self.action]
        grad = self.state.reshape(-1, 1) @ dlog.reshape(1, -1)

        # .at[].set replaces the removed jax.ops.index_update. Per the JAX
        # docs, scatter updates at out-of-bounds indices are dropped.
        idx = self.current_episode_length
        self.episode_rewards = self.episode_rewards.at[idx].set(reward)
        self.episode_grads = self.episode_grads.at[idx].set(grad)
        self.current_episode_length += 1

    def update(self) -> None:
        """
        Description: apply a REINFORCE update using the episode stored by feed()

        Each stored log-policy gradient is scaled by the discounted FUTURE
        return from that step:
            W += lr * sum_i grad_i * G_i,   G_i = sum_{k >= i} gamma^(k - i) * r_k

        Args:
            None

        Returns:
            None
        """
        # The buffers cannot hold more than max_episode_length entries.
        T = min(self.current_episode_length, self.max_episode_length)
        rewards = self.episode_rewards[:T]

        # offsets[i, k] = k - i; only k >= i contributes to the return from step i.
        steps = jnp.arange(T)
        offsets = steps[None, :] - steps[:, None]
        discounts = jnp.where(offsets >= 0, self.gamma ** jnp.maximum(offsets, 0), 0.0)
        returns = discounts @ rewards  # shape (T,)

        # sum_i returns[i] * episode_grads[i] has the same shape as W.
        self.W += self.lr * jnp.tensordot(returns, self.episode_grads[:T], axes=1)

        # reset episode length
        self.current_episode_length = 0
# Example no. 3
0
class MountainCar(Env):
    """
    Continuous-action mountain car.

    The state is [position, velocity]. The action is a force clipped to
    [min_action, max_action].
    """

    def __init__(self, goal_velocity=0, seed=0):
        # render() lazily creates the viewer and expects this attribute to exist.
        self.viewer = None
        self.min_action = -1.0
        self.max_action = 1.0
        self.min_position = -1.2
        self.max_position = 0.6
        self.max_speed = 0.07
        self.goal_position = 0.45  # was 0.5 in gym, 0.45 in Arnaud de Broissia's version
        self.goal_velocity = goal_velocity
        self.power = 0.0015
        self.random = Random(seed)

        self.low_state = np.array([self.min_position, -self.max_speed],
                                  dtype=np.float32)
        self.high_state = np.array([self.max_position, self.max_speed],
                                   dtype=np.float32)

        self.action_space = spaces.Box(low=self.min_action,
                                       high=self.max_action,
                                       shape=(1, ),
                                       dtype=np.float32)
        self.observation_space = spaces.Box(low=self.low_state,
                                            high=self.high_state,
                                            dtype=np.float32)

        self.reset()

    def dynamics(self, state, action):
        """
        Return the next [position, velocity] without mutating self.

        This method is not decorated with ``jax.jit``. Jitting a method
        directly makes JAX try to trace ``self``, which fails.
        """
        position = state[0]
        velocity = state[1]

        force = jnp.minimum(jnp.maximum(action, self.min_action),
                            self.max_action)

        velocity += force * self.power - 0.0025 * jnp.cos(3 * position)
        velocity = jnp.clip(velocity, -self.max_speed, self.max_speed)

        position += velocity
        position = jnp.clip(position, self.min_position, self.max_position)

        # Hitting the left wall while moving left stops the car.
        # jnp.where works elementwise, so it accepts scalar or shape-(1,) actions.
        reset_velocity = (position == self.min_position) & (velocity < 0)
        velocity = jnp.where(reset_velocity, 0.0, velocity)

        return jnp.array([position, velocity])

    def step(self, action):
        """Advance one step; return (state, reward, done, info)."""
        self.state = self.dynamics(self.state, action)
        position = self.state[0]
        velocity = self.state[1]

        # `done` stays a JAX boolean array (no conversion to a Python bool).
        done = (position >= self.goal_position) & (velocity >=
                                                   self.goal_velocity)

        reward = 100.0 * done
        reward -= jnp.power(action, 2) * 0.1

        return self.state, reward, done, {}

    def reset(self):
        """Start at a uniform position in [-0.6, 0.4) with zero velocity."""
        self.state = jnp.array([
            jax.random.uniform(self.random.generate_key(),
                               minval=-0.6,
                               maxval=0.4), 0
        ])
        return self.state

    def _height(self, xs):
        return jnp.sin(3 * xs) * 0.45 + 0.55

    def render(self, mode="human"):
        screen_width = 600
        screen_height = 400

        world_width = self.max_position - self.min_position
        scale = screen_width / world_width
        carwidth = 40
        carheight = 20

        if self.viewer is None:
            from gym.envs.classic_control import rendering

            self.viewer = rendering.Viewer(screen_width, screen_height)
            xs = np.linspace(self.min_position, self.max_position, 100)
            ys = self._height(xs)
            xys = list(zip((xs - self.min_position) * scale, ys * scale))

            self.track = rendering.make_polyline(xys)
            self.track.set_linewidth(4)
            self.viewer.add_geom(self.track)

            clearance = 10

            l, r, t, b = -carwidth / 2, carwidth / 2, carheight, 0
            car = rendering.FilledPolygon([(l, b), (l, t), (r, t), (r, b)])
            car.add_attr(rendering.Transform(translation=(0, clearance)))
            self.cartrans = rendering.Transform()
            car.add_attr(self.cartrans)
            self.viewer.add_geom(car)
            frontwheel = rendering.make_circle(carheight / 2.5)
            frontwheel.set_color(0.5, 0.5, 0.5)
            frontwheel.add_attr(
                rendering.Transform(translation=(carwidth / 4, clearance)))
            frontwheel.add_attr(self.cartrans)
            self.viewer.add_geom(frontwheel)
            backwheel = rendering.make_circle(carheight / 2.5)
            backwheel.add_attr(
                rendering.Transform(translation=(-carwidth / 4, clearance)))
            backwheel.add_attr(self.cartrans)
            backwheel.set_color(0.5, 0.5, 0.5)
            self.viewer.add_geom(backwheel)
            flagx = (self.goal_position - self.min_position) * scale
            flagy1 = self._height(self.goal_position) * scale
            flagy2 = flagy1 + 50
            flagpole = rendering.Line((flagx, flagy1), (flagx, flagy2))
            self.viewer.add_geom(flagpole)
            flag = rendering.FilledPolygon([(flagx, flagy2),
                                            (flagx, flagy2 - 10),
                                            (flagx + 25, flagy2 - 5)])
            flag.set_color(0.8, 0.8, 0)
            self.viewer.add_geom(flag)

        pos = self.state[0]
        self.cartrans.set_translation((pos - self.min_position) * scale,
                                      self._height(pos) * scale)
        self.cartrans.set_rotation(math.cos(3 * pos))

        return self.viewer.render(return_rgb_array=mode == "rgb_array")

    def close(self):
        """Close the render window, if one was opened."""
        if self.viewer:
            self.viewer.close()
            self.viewer = None
# Example no. 4
0
class Acrobot(Env):
    """
    Acrobot is a 2-link pendulum with only the second joint actuated.
    Initially, both links point downwards. The goal is to swing the
    end-effector at a height at least the length of one link above the base.
    Both links can swing freely and can pass by each other, i.e., they don't
    collide when they have the same angle.
    **STATE:**
    The internal state ``self.state`` holds the two rotational joint angles
    and the joint angular velocities: [theta1 theta2 thetaDot1 thetaDot2].
    The ``observation`` property expands it into
    [cos(theta1) sin(theta1) cos(theta2) sin(theta2) thetaDot1 thetaDot2].
    For the first link, an angle of 0 corresponds to the link pointing downwards.
    The angle of the second link is relative to the angle of the first link.
    An angle of 0 corresponds to having the same angle between the two links.
    An observation of [1, 0, 1, 0, ..., ...] means that both links point downwards.
    **ACTIONS:**
    ``step`` takes a continuous torque on the joint between the two pendulum
    links, passed through ``bound(action, -1.0, 1.0)``. ``AVAIL_TORQUE`` keeps
    the classic discrete choices (-1, 0, +1).
    **REFERENCE:**
    .. warning::
        This version of the domain uses the Runge-Kutta method for integrating
        the system dynamics and is more realistic, but also considerably harder
        than the original version which employs Euler integration,
        see the AcrobotLegacy class.
    """

    dt = 0.2

    LINK_LENGTH_1 = 1.0  # [m]
    LINK_LENGTH_2 = 1.0  # [m]
    LINK_MASS_1 = 1.0  #: [kg] mass of link 1
    LINK_MASS_2 = 1.0  #: [kg] mass of link 2
    LINK_COM_POS_1 = 0.5  #: [m] position of the center of mass of link 1
    LINK_COM_POS_2 = 0.5  #: [m] position of the center of mass of link 2
    LINK_MOI = 1.0  #: moments of inertia for both links

    MAX_VEL_1 = 4 * pi
    MAX_VEL_2 = 9 * pi

    AVAIL_TORQUE = jnp.array([-1.0, 0.0, +1])

    torque_noise_max = 0.0

    #: use dynamics equations from the nips paper or the book
    book_or_nips = "book"
    action_arrow = None
    domain_fig = None
    actions_num = 3

    def __init__(self, seed=0, horizon=50):
        high = np.array([1.0, 1.0, 1.0, 1.0, self.MAX_VEL_1, self.MAX_VEL_2],
                        dtype=np.float32)
        low = -high
        self.random = Random(seed)
        self.observation_space = spaces.Box(low=low,
                                            high=high,
                                            dtype=np.float32)
        self.action_space = spaces.Discrete(3)
        self.action_dim = 1
        self.H = horizon
        self.nsamples = 0

        def _dynamics(state, action):
            """Integrate one dt with RK4, then wrap angles and bound velocities."""
            self.nsamples += 1
            # Augment the state with our force action so it can be passed to _dsdt
            augmented_state = jnp.append(state, action)

            new_state = rk4(self._dsdt, augmented_state, [0, self.dt])
            # Only the final integration timestep matters; drop the action entry.
            # (odeint was too slow, hence the hand-rolled rk4.)
            new_state = jnp.asarray(new_state[-1])[:4]

            # .at[].set replaces the removed jax.ops.index_update.
            new_state = new_state.at[0].set(wrap(new_state[0], -pi, pi))
            new_state = new_state.at[1].set(wrap(new_state[1], -pi, pi))
            new_state = new_state.at[2].set(
                bound(new_state[2], -self.MAX_VEL_1, self.MAX_VEL_1))
            new_state = new_state.at[3].set(
                bound(new_state[3], -self.MAX_VEL_2, self.MAX_VEL_2))

            return new_state

        def c(x, u):
            """Cost: squared torque plus a term that decreases as both angles approach 0."""
            return u[0]**2 + 1 - jnp.exp(0.5 * cos(x[0]) + 0.5 * cos(x[1]) - 1)

        self.dynamics = _dynamics
        self.reward_fn = c

        self.f, self.f_x, self.f_u = (
            _dynamics,
            jax.jacfwd(_dynamics, argnums=0),
            jax.jacfwd(_dynamics, argnums=1),
        )
        self.c, self.c_x, self.c_u, self.c_xx, self.c_uu = (
            c,
            jax.grad(c, argnums=0),
            jax.grad(c, argnums=1),
            jax.hessian(c, argnums=0),
            jax.hessian(c, argnums=1),
        )
        self.reset()

    def reset(self):
        """Sample a 4-dim state uniformly in [-0.1, 0.1)."""
        self.state = jax.random.uniform(
            self.random.generate_key(),
            shape=(4, ),
            minval=-0.1,
            maxval=0.1
        )
        return self.state

    def step(self, action):
        """Apply a continuous torque; return (state, cost, terminal, info)."""
        torque = bound(action, -1.0, 1.0)  # continuous action space

        # Add noise to the force action. Use the seeded generator: the class
        # has no `np_random` attribute.
        if self.torque_noise_max > 0:
            torque += jax.random.uniform(self.random.generate_key(),
                                         minval=-self.torque_noise_max,
                                         maxval=self.torque_noise_max)
        self.state = self.dynamics(self.state, torque)
        terminal = self._terminal()
        # The cost uses the raw action, not the bounded/noisy torque.
        reward = self.reward_fn(self.state, action)

        # TODO: should this return self.state (dim 4) or self.observation (dim 6)?
        return self.state, reward, terminal, {}

    @property
    def observation(self):
        """6-dim observation built from the 4-dim internal state."""
        return jnp.array([
            cos(self.state[0]),
            sin(self.state[0]),
            cos(self.state[1]),
            sin(self.state[1]),
            self.state[2],
            self.state[3],
        ])

    def _terminal(self):
        """True when the end-effector is more than one link length above the base."""
        return -cos(self.state[0]) - cos(self.state[1] + self.state[0]) > 1.0

    def _dsdt(self, augmented_state, t):
        """Time derivative of [theta1, theta2, dtheta1, dtheta2, action]."""
        m1 = self.LINK_MASS_1
        m2 = self.LINK_MASS_2
        l1 = self.LINK_LENGTH_1
        lc1 = self.LINK_COM_POS_1
        lc2 = self.LINK_COM_POS_2
        I1 = self.LINK_MOI
        I2 = self.LINK_MOI
        g = 9.8
        a = augmented_state[-1]
        s = augmented_state[:-1]

        theta1 = s[0]
        theta2 = s[1]
        dtheta1 = s[2]
        dtheta2 = s[3]

        d1 = m1 * lc1**2 + m2 * (l1**2 + lc2**2 +
                                 2 * l1 * lc2 * cos(theta2)) + I1 + I2
        d2 = m2 * (lc2**2 + l1 * lc2 * cos(theta2)) + I2
        phi2 = m2 * lc2 * g * cos(theta1 + theta2 - pi / 2.0)
        phi1 = (-m2 * l1 * lc2 * dtheta2**2 * sin(theta2) -
                2 * m2 * l1 * lc2 * dtheta2 * dtheta1 * sin(theta2) +
                (m1 * lc1 + m2 * l1) * g * cos(theta1 - pi / 2) + phi2)
        if self.book_or_nips == "nips":
            # the following line is consistent with the description in the
            # paper
            ddtheta2 = (a + d2 / d1 * phi1 - phi2) / (m2 * lc2**2 + I2 -
                                                      d2**2 / d1)
        else:
            # the following line is consistent with the java implementation and the
            # book
            ddtheta2 = (a + d2 / d1 * phi1 - m2 * l1 * lc2 * dtheta1**2 *
                        sin(theta2) - phi2) / (m2 * lc2**2 + I2 - d2**2 / d1)
        ddtheta1 = -(d2 * ddtheta2 + phi1) / d1
        return (dtheta1, dtheta2, ddtheta1, ddtheta2, 0.0)  # 4-state version
# Example no. 5
0
class LDS(Env):
    """
    Linear dynamical system x' = A x + B u (+ optional Gaussian noise),
    observed through y = C x, with a simple 2-D renderer.
    """

    def __init__(self,
                 state_size=1,
                 action_size=1,
                 A=None,
                 B=None,
                 C=None,
                 seed=0,
                 noise="normal"):
        """
        Args:
            state_size (int): dimension of the state vector
            action_size (int): dimension of the action vector
            A (optional): (state_size, state_size) dynamics matrix
            B (optional): (state_size, action_size) control matrix
            C (optional): observation matrix; identity when omitted
            seed (int): seed for the internal key generator
            noise (str): "normal" adds N(0, 0.1^2) process noise in step()
        """
        if A is not None:
            assert (
                A.shape[0] == state_size and A.shape[1] == state_size
            ), "ERROR: Your input dynamics matrix does not have the correct shape."

        if B is not None:
            assert (
                B.shape[0] == state_size and B.shape[1] == action_size
            ), "ERROR: Your input control matrix does not have the correct shape."

        self.viewer = None
        self.random = Random(seed)
        self.noise = noise

        self.state_size, self.action_size = state_size, action_size
        # Two random unit vectors that render() uses to project the state onto 2-D.
        self.proj_vector1 = jax.random.normal(self.random.generate_key(),
                                              (state_size, ))
        self.proj_vector1 = self.proj_vector1 / jnp.linalg.norm(
            self.proj_vector1)
        self.proj_vector2 = jax.random.normal(self.random.generate_key(),
                                              (state_size, ))
        self.proj_vector2 = self.proj_vector2 / jnp.linalg.norm(
            self.proj_vector2)

        self.A = (A if A is not None else jax.random.normal(
            self.random.generate_key(), shape=(state_size, state_size)))

        self.B = (B if B is not None else jax.random.normal(
            self.random.generate_key(), shape=(state_size, action_size)))

        self.C = (C if C is not None else jax.numpy.identity(self.state_size))

        self.t = 0

        self.reset()

    def step(self, action):
        """Advance the state one step (with optional noise) and refresh obs."""
        self.state = self.A @ self.state + self.B @ action
        if self.noise == "normal":
            # Noise must match the state's shape, not (state_size, action_size).
            # It is drawn from the seeded generator so `seed` controls it.
            self.state = self.state + 0.1 * jax.random.normal(
                self.random.generate_key(), shape=self.state.shape)
        self.obs = self.C @ self.state

    @staticmethod
    @jax.jit
    def _linear_dynamics(A, B, state, action):
        # Pure array function so jit can trace it. Jitting the bound method
        # directly would make JAX try to trace ``self`` and fail.
        return A @ state + B @ action

    def dynamics(self, state, action):
        """Return the noise-free next state A @ state + B @ action."""
        return self._linear_dynamics(self.A, self.B, state, action)

    def reset(self):
        """Draw a fresh standard-normal (state_size, 1) state and its observation."""
        self.state = jax.random.normal(self.random.generate_key(),
                                       shape=(self.state_size, 1))
        self.obs = self.C @ self.state

    def render(self, mode="human"):
        if self.viewer is None:
            from gym.envs.classic_control import rendering

            self.viewer = rendering.Viewer(1000, 1000)
            self.viewer.set_bounds(-2.5, 2.5, -2.5, 2.5)
            fname = path.dirname(__file__) + "/classic/assets/lds_arrow.png"
            self.img = rendering.Image(fname, 0.35, 0.35)
            self.img.set_color(1.0, 1.0, 1.0)
            self.imgtrans = rendering.Transform()
            self.img.add_attr(self.imgtrans)
            fnamewind = path.dirname(__file__) + "/classic/assets/lds_grid.png"
            self.imgwind = rendering.Image(fnamewind, 5.0, 5.0)
            self.imgwind.set_color(0.5, 0.5, 0.5)
            self.imgtranswind = rendering.Transform()
            self.imgwind.add_attr(self.imgtranswind)

        self.viewer.add_onetime(self.imgwind)
        self.viewer.add_onetime(self.img)
        cur_x, cur_y = self.imgtrans.translation
        # Project the state onto the two random unit vectors for a 2-D position.
        new_x, new_y = jnp.dot(self.state.squeeze(),
                               self.proj_vector1), jnp.dot(
                                   self.state.squeeze(), self.proj_vector2)
        diff_x = new_x - cur_x
        diff_y = new_y - cur_y
        # Point the arrow along the direction of motion.
        new_rotation = jnp.arctan2(diff_y, diff_x)
        self.imgtrans.set_translation(new_x, new_y)
        self.imgtrans.set_rotation(new_rotation)

        return self.viewer.render(return_rgb_array=mode == "rgb_array")

    def close(self):
        """Close the render window, if one was opened."""
        if self.viewer:
            self.viewer.close()
            self.viewer = None
# Example no. 6
0
class Pendulum(Env):
    """
    Torque-controlled pendulum exposing its dynamics ``f`` and cost ``c``
    together with JAX derivatives (f_x, f_u, c_x, c_u, c_xx, c_uu) for
    ILQR-style planners.

    The state is [theta, theta_dot]. The action is a torque clipped to
    [-max_torque, max_torque].
    """
    max_speed = 8.0
    max_torque = 2.0  # gym 2.
    high = np.array([1.0, 1.0, max_speed])

    action_space = spaces.Box(low=-max_torque,
                              high=max_torque,
                              shape=(1, ),
                              dtype=np.float32)
    observation_space = spaces.Box(low=-high, high=high, dtype=np.float32)
    metadata = {'render.modes': ['human', 'rgb_array']}

    def __init__(self, reward_fn=None, seed=0, horizon=50):
        """
        Args:
            reward_fn (optional): cost function (x, u) -> scalar; defaults to `c`
            seed (int): seed for the internal key generator
            horizon (int): planning horizon stored as self.H
        """
        self.dt = 0.05
        self.viewer = None

        self.state_size = 2
        self.action_size = 1
        self.action_dim = 1  # redundant with action_size but needed by ILQR

        self.H = horizon

        self.n, self.m = 2, 1
        self.angle_normalize = angle_normalize
        self.nsamples = 0
        self.last_u = None
        self.random = Random(seed)

        self.reset()

        def _dynamics(state, action):
            """One explicit-Euler step of the pendulum equations."""
            self.nsamples += 1
            # NOTE(review): when called under jax.jacfwd this stores a tracer
            # in last_u. render() then tests it for truthiness, so confirm
            # callers render only after a concrete call.
            self.last_u = action
            th, thdot = state
            g = 10.0
            m = 1.0
            ell = 1.0
            dt = self.dt

            # Limit the control signal to the admissible torque range.
            action = jnp.clip(action, -self.max_torque, self.max_torque)

            newthdot = (thdot +
                        (-3 * g / (2 * ell) * jnp.sin(th + jnp.pi) + 3.0 /
                         (m * ell**2) * action) * dt)
            # The angle uses the unclipped velocity; the speed is clipped afterwards.
            newth = th + newthdot * dt
            newthdot = jnp.clip(newthdot, -self.max_speed, self.max_speed)

            return jnp.reshape(jnp.array([newth, newthdot]), (2, ))

        @jax.jit
        def c(x, u):
            """Cost: squared normalized angle plus 0.1 * squared torque."""
            return angle_normalize(x[0])**2 + .1 * (u[0]**2)

        self.reward_fn = reward_fn or c
        self.dynamics = _dynamics
        self.f, self.f_x, self.f_u = (
            _dynamics,
            jax.jacfwd(_dynamics, argnums=0),
            jax.jacfwd(_dynamics, argnums=1),
        )
        self.c, self.c_x, self.c_u, self.c_xx, self.c_uu = (
            c,
            jax.grad(c, argnums=0),
            jax.grad(c, argnums=1),
            jax.hessian(c, argnums=0),
            jax.hessian(c, argnums=1),
        )

    def reset(self):
        """Sample theta in [-pi, pi) and theta_dot in [-1, 1)."""
        th = jax.random.uniform(self.random.generate_key(),
                                minval=-jnp.pi,
                                maxval=jnp.pi)
        thdot = jax.random.uniform(self.random.generate_key(),
                                   minval=-1.0,
                                   maxval=1.0)

        self.state = jnp.array([th, thdot])

        return self.state

    def render(self, mode="human"):
        if self.viewer is None:
            from gym.envs.classic_control import rendering

            self.viewer = rendering.Viewer(500, 500)
            self.viewer.set_bounds(-2.2, 2.2, -2.2, 2.2)
            rod = rendering.make_capsule(1, 0.2)
            rod.set_color(0.8, 0.3, 0.3)
            self.pole_transform = rendering.Transform()
            rod.add_attr(self.pole_transform)
            self.viewer.add_geom(rod)
            axle = rendering.make_circle(0.05)
            axle.set_color(0, 0, 0)
            self.viewer.add_geom(axle)
            fname = path.join(path.dirname(__file__), "assets/clockwise.png")
            self.img = rendering.Image(fname, 1.0, 1.0)
            self.imgtrans = rendering.Transform()
            self.img.add_attr(self.imgtrans)

        self.viewer.add_onetime(self.img)
        self.pole_transform.set_rotation(self.state[0] + np.pi / 2)
        # Scale the torque arrow by the last applied control.
        if self.last_u:
            self.imgtrans.scale = (-self.last_u / 2, np.abs(self.last_u) / 2)

        return self.viewer.render(return_rgb_array=mode == "rgb_array")

    def close(self):
        """Close the render window, if one was opened."""
        if self.viewer:
            self.viewer.close()
            self.viewer = None
# Example no. 7
0
class DRC(Agent):
    """
    Disturbance-response controller.

    The action is -K @ obs plus a learned linear combination (M) of the last m
    "natural" observations. M is updated by gradient steps on a surrogate
    cost computed through the truncated Markov operator G.
    """

    def __init__(
        self,
        A: jnp.ndarray,
        B: jnp.ndarray,
        C: jnp.ndarray = None,
        K: jnp.ndarray = None,
        cost_fn: Callable[[jnp.ndarray, jnp.ndarray], Real] = None,
        m: int = 10,
        h: int = 50,
        lr_scale: Real = 0.03,
        decay: bool = True,
        RM: int = 1000,
        seed: int = 0,
    ) -> None:
        """
        Description: Initialize the dynamics of the model.

        Args:
            A (jnp.ndarray): system dynamics matrix
            B (jnp.ndarray): control matrix of shape (d_state, d_action)
            C (jnp.ndarray): observation matrix; identity when None
            K (jnp.ndarray): fixed linear gain of shape (d_action, d_obs); zeros when None
            cost_fn (Callable[[jnp.ndarray, jnp.ndarray], Real]): cost(state, action);
                defaults to quad_loss
            m (int): number of past natural observations used by the controller
            h (int): length of the truncated Markov operator and action history
            lr_scale (Real): base step size (also scales the random init of M)
            decay (bool): if True, the step size at time t is lr_scale / (t + 1)
            RM (Real): bound on the 2-norm of the flattened M
            seed (int): seed for the internal key generator
        """

        cost_fn = cost_fn or quad_loss

        self.random = Random(seed)

        d_state, d_action = B.shape  # State & Action Dimensions

        C = jnp.identity(d_state) if C is None else C

        d_obs = C.shape[0]  # Observation Dimension

        self.t = 0  # Time Counter (for decaying learning rate)

        self.m, self.h = m, h

        self.lr_scale, self.decay = lr_scale, decay

        self.RM = RM

        # Construct truncated markov operator G: G[i] = C A^i B
        self.G = jnp.zeros((h, d_obs, d_action))
        A_power = jnp.identity(d_state)
        for i in range(h):
            self.G = self.G.at[i].set(C @ A_power @ B)
            A_power = A_power @ A

        # Model Parameters
        # initial linear policy / perturbation contributions / bias
        self.K = K if K is not None else jnp.zeros((d_action, d_obs))

        self.M = lr_scale * jax.random.normal(
            self.random.generate_key(), shape=(m, d_action, d_obs)
        )

        # Past m nature y's such that y_nat[0] is the most recent
        self.y_nat = jnp.zeros((m, d_obs, 1))

        # Past h u's such that u[0] is the most recent
        self.us = jnp.zeros((h, d_action, 1))

        def policy_loss(M, G, y_nat, us):
            """Surrogate cost function"""

            def action(obs):
                """Action function"""
                return -self.K @ obs + jnp.tensordot(M, y_nat, axes=([0, 2], [0, 1]))

            final_state = y_nat[0] + jnp.tensordot(G, us, axes=([0, 2], [0, 1]))
            return cost_fn(final_state, action(final_state))

        self.policy_loss = policy_loss
        self.grad = jit(grad(policy_loss))

    def __call__(self, obs: jnp.ndarray) -> jnp.ndarray:
        """
        Description: Return the action based on current state and internal parameters.

        Args:
            obs (jnp.ndarray): current observation

        Returns:
           jnp.ndarray: action to take
        """
        # update y_nat
        self.update_noise(obs)

        # get action
        action = self.get_action(obs)

        # update Parameters
        self.update_params(obs, action)

        return action

    def get_action(self, obs: jnp.ndarray) -> jnp.ndarray:
        """
        Description: get action from the current observation.

        Args:
            obs (jnp.ndarray): current observation

        Returns:
            jnp.ndarray
        """

        return -self.K @ obs + jnp.tensordot(self.M, self.y_nat, axes=([0, 2], [0, 1]))

    def update(self, obs: jnp.ndarray, u: jnp.ndarray) -> None:
        """Record obs/u and take a parameter step without producing an action."""
        self.update_noise(obs)
        self.update_params(obs, u)

    def update_noise(self, obs: jnp.ndarray) -> None:
        """Push the natural observation obs - sum_i G[i] @ us[i] onto y_nat."""
        y_nat = obs - jnp.tensordot(self.G, self.us, axes=([0, 2], [0, 1]))
        self.y_nat = jnp.roll(self.y_nat, 1, axis=0)
        self.y_nat = self.y_nat.at[0].set(y_nat)

    def update_params(self, obs: jnp.ndarray, u: jnp.ndarray) -> None:
        """
        Description: update agent internal state.

        Args:
            obs (jnp.ndarray): current observation
            u (jnp.ndarray): action just taken

        Returns:
            None
        """

        # update parameters
        delta_M = self.grad(self.M, self.G, self.y_nat, self.us)

        # With decay, the step size is lr_scale / (t + 1); otherwise it stays
        # lr_scale. The old lax.cond returned 1.0 in the no-decay case.
        lr = self.lr_scale / (self.t + 1) if self.decay else self.lr_scale

        self.M -= lr * delta_M

        # Project M back onto the ball of radius RM.
        norm = jnp.linalg.norm(self.M)
        self.M = jax.lax.cond(
            norm > self.RM,
            lambda x: x * (self.RM / norm),
            lambda x: x,
            self.M,
        )

        # update us
        self.us = jnp.roll(self.us, 1, axis=0)
        self.us = self.us.at[0].set(u)

        self.t += 1
# Example no. 8
0
class MountainCar(Env):
    """
    Continuous-action mountain-car environment with JAX dynamics.

    The state is ``[position, velocity]``. The dynamics ``f`` and the cost
    ``c`` are exposed together with their JAX derivatives
    (``f_x``, ``f_u``, ``c_x``, ``c_u``, ``c_xx``, ``c_uu``).
    """

    metadata = {'render.modes': ['human', 'rgb_array']}

    def __init__(self, goal_velocity=0, seed=0, horizon=50):
        """
        Args:
            goal_velocity: minimum velocity required at the goal position
            seed: seed for the environment's random key generator
            horizon: episode horizon, stored as ``self.H``
        """
        self.viewer = None
        self.min_action = -1.0
        self.max_action = 1.0
        self.min_position = -1.2
        self.max_position = 0.6
        self.max_speed = 0.07
        self.goal_position = 0.45  # was 0.5 in gym, 0.45 in Arnaud de Broissia's version
        self.goal_velocity = goal_velocity
        self.power = 0.0015
        self.H = horizon
        self.action_dim = 1
        self.random = Random(seed)

        self.low_state = np.array([self.min_position, -self.max_speed], dtype=np.float32)
        self.high_state = np.array([self.max_position, self.max_speed], dtype=np.float32)

        self.action_space = spaces.Box(
            low=self.min_action, high=self.max_action, shape=(1,), dtype=np.float32
        )
        self.observation_space = spaces.Box(
            low=self.low_state, high=self.high_state, dtype=np.float32
        )
        self.nsamples = 0

        # Not jitted: under jax.jit the nsamples side effect would only run at
        # trace time instead of on every call.
        def _dynamics(state, action):
            """Advance ``state`` one step under ``action`` (clipped to the action bounds)."""
            # Counts every Python-level call, including those made by f_x / f_u.
            self.nsamples += 1
            position = state[0]
            velocity = state[1]

            force = jnp.minimum(jnp.maximum(action, self.min_action), self.max_action)

            velocity += force * self.power - 0.0025 * jnp.cos(3 * position)
            velocity = jnp.clip(velocity, -self.max_speed, self.max_speed)

            position += velocity
            position = jnp.clip(position, self.min_position, self.max_position)

            # Hitting the left wall while moving left stops the car.
            # jnp.where keeps this traceable and works for both scalar and
            # length-1 actions (no indexing into the mask).
            velocity = jnp.where(
                (position == self.min_position) & (velocity < 0), 0.0, velocity
            )
            return jnp.reshape(jnp.array([position, velocity]), (2,))

        @jax.jit
        def c(x, u):
            """
            Cost of state ``x`` under control ``u``: -100 when ``x`` is at the
            goal, plus the control penalty ``0.1 * (u[0] + 1) ** 2``.
            """
            # Read the state from the argument, not self.state: under jax.jit a
            # closed-over self.state is frozen at trace time, and c_x / c_xx
            # must differentiate with respect to x.
            position, velocity = x[0], x[1]
            done = (position >= self.goal_position) & (velocity >= self.goal_velocity)
            # NOTE(review): the +1 offset penalizes u = 0; confirm this is
            # intended rather than 0.1 * u[0] ** 2.
            return -100.0 * done + 0.1 * (u[0] + 1) ** 2

        self.reward_fn = c
        self.dynamics = _dynamics
        self.f, self.f_x, self.f_u = (
            _dynamics,
            jax.jacfwd(_dynamics, argnums=0),
            jax.jacfwd(_dynamics, argnums=1),
        )
        self.c, self.c_x, self.c_u, self.c_xx, self.c_uu = (
            c,
            jax.grad(c, argnums=0),
            jax.grad(c, argnums=1),
            jax.hessian(c, argnums=0),
            jax.hessian(c, argnums=1),
        )

        self.reset()

    def step(self, action):
        """
        Advance the environment one step.

        Returns:
            (state, reward, done, info). ``reward`` is the value of
            ``reward_fn`` (the cost ``c``, negative on reaching the goal);
            ``done`` is a JAX boolean array, not a Python bool.
        """
        self.state = self.dynamics(self.state, action)
        position = self.state[0]
        velocity = self.state[1]

        # Goal reached: far enough right and moving fast enough to the right.
        done = (position >= self.goal_position) & (velocity >= self.goal_velocity)

        reward = self.reward_fn(self.state, action)

        return self.state, reward, done, {}

    def reset(self):
        """Sample a start position uniformly in [-0.6, 0.4) with zero velocity."""
        self.state = jnp.array(
            [jax.random.uniform(self.random.generate_key(), minval=-0.6, maxval=0.4), 0]
        )
        return self.state

    def _height(self, xs):
        """Track height at position(s) ``xs`` (used only for rendering)."""
        return jnp.sin(3 * xs) * 0.45 + 0.55

    def render(self, mode="human"):
        """Draw the track, car and flag with gym's classic-control renderer."""
        screen_width = 600
        screen_height = 400

        world_width = self.max_position - self.min_position
        scale = screen_width / world_width
        carwidth = 40
        carheight = 20

        if self.viewer is None:
            from gym.envs.classic_control import rendering

            self.viewer = rendering.Viewer(screen_width, screen_height)
            xs = np.linspace(self.min_position, self.max_position, 100)
            ys = self._height(xs)
            xys = list(zip((xs - self.min_position) * scale, ys * scale))

            self.track = rendering.make_polyline(xys)
            self.track.set_linewidth(4)
            self.viewer.add_geom(self.track)

            clearance = 10

            l, r, t, b = -carwidth / 2, carwidth / 2, carheight, 0
            car = rendering.FilledPolygon([(l, b), (l, t), (r, t), (r, b)])
            car.add_attr(rendering.Transform(translation=(0, clearance)))
            self.cartrans = rendering.Transform()
            car.add_attr(self.cartrans)
            self.viewer.add_geom(car)
            frontwheel = rendering.make_circle(carheight / 2.5)
            frontwheel.set_color(0.5, 0.5, 0.5)
            frontwheel.add_attr(rendering.Transform(translation=(carwidth / 4, clearance)))
            frontwheel.add_attr(self.cartrans)
            self.viewer.add_geom(frontwheel)
            backwheel = rendering.make_circle(carheight / 2.5)
            backwheel.add_attr(rendering.Transform(translation=(-carwidth / 4, clearance)))
            backwheel.add_attr(self.cartrans)
            backwheel.set_color(0.5, 0.5, 0.5)
            self.viewer.add_geom(backwheel)
            flagx = (self.goal_position - self.min_position) * scale
            flagy1 = self._height(self.goal_position) * scale
            flagy2 = flagy1 + 50
            flagpole = rendering.Line((flagx, flagy1), (flagx, flagy2))
            self.viewer.add_geom(flagpole)
            flag = rendering.FilledPolygon(
                [(flagx, flagy2), (flagx, flagy2 - 10), (flagx + 25, flagy2 - 5)]
            )
            flag.set_color(0.8, 0.8, 0)
            self.viewer.add_geom(flag)

        pos = self.state[0]
        self.cartrans.set_translation((pos - self.min_position) * scale, self._height(pos) * scale)
        self.cartrans.set_rotation(math.cos(3 * pos))

        return self.viewer.render(return_rgb_array=mode == "rgb_array")
Esempio n. 9
0
class Pendulum(Env):
    """
    Pendulum environment with JAX dynamics; the state is ``[theta, theta_dot]``.
    """

    max_speed = 8.0
    max_torque = 2.0  # gym 2.
    # NOTE(review): these bounds are 3-D (gym's [cos, sin, thdot] observation)
    # while self.state is 2-D [th, thdot]; confirm observation_space is intended.
    high = np.array([1.0, 1.0, max_speed])

    action_space = gym.spaces.Box(low=-max_torque, high=max_torque, shape=(1,), dtype=np.float32)
    observation_space = gym.spaces.Box(low=-high, high=high, dtype=np.float32)

    def __init__(self, reward_fn=None, seed=0):
        """
        Args:
            reward_fn: reward function; falls back to ``default_reward_fn``
            seed: seed for the environment's random key generator
        """
        self.reward_fn = reward_fn or default_reward_fn
        self.dt = 0.05
        self.viewer = None

        self.state_size = 2
        self.action_size = 1

        self.n, self.m = 2, 1
        self.angle_normalize = angle_normalize

        self.random = Random(seed)

        self.reset()

    # NOTE(review): jitting a method passes `self` as a traced argument; this
    # relies on Env being a valid JAX type (e.g. a registered pytree) — confirm.
    @jax.jit
    def dynamics(self, state, action):
        """Return the next ``[theta, theta_dot]`` after applying torque ``action`` for ``dt``."""
        th, thdot = state
        g = 10.0
        m = 1.0
        ell = 1.0
        dt = self.dt

        # Limit the torque to the admissible range [-max_torque, max_torque].
        action = jnp.clip(action, -self.max_torque, self.max_torque)

        newthdot = (
            thdot + (-3 * g / (2 * ell) * jnp.sin(th + jnp.pi) + 3.0 / (m * ell ** 2) * action) * dt
        )
        # The angle is integrated with the unclipped new velocity; the velocity
        # is clipped afterwards (same ordering as gym's Pendulum).
        newth = th + newthdot * dt
        newthdot = jnp.clip(newthdot, -self.max_speed, self.max_speed)

        return jnp.array([newth, newthdot])

    def reset(self):
        """Sample theta uniformly in [-pi, pi) and theta_dot uniformly in [-1, 1)."""
        th = jax.random.uniform(self.random.generate_key(), minval=-jnp.pi, maxval=jnp.pi)
        thdot = jax.random.uniform(self.random.generate_key(), minval=-1.0, maxval=1.0)

        self.state = jnp.array([th, thdot])

        return self.state

    def render(self, mode="human"):
        """Draw the pendulum with gym's classic-control renderer."""
        if self.viewer is None:
            from gym.envs.classic_control import rendering

            self.viewer = rendering.Viewer(500, 500)
            self.viewer.set_bounds(-2.2, 2.2, -2.2, 2.2)
            rod = rendering.make_capsule(1, 0.2)
            rod.set_color(0.8, 0.3, 0.3)
            self.pole_transform = rendering.Transform()
            rod.add_attr(self.pole_transform)
            self.viewer.add_geom(rod)
            axle = rendering.make_circle(0.05)
            axle.set_color(0, 0, 0)
            self.viewer.add_geom(axle)
            fname = path.join(path.dirname(__file__), "assets/clockwise.png")
            self.img = rendering.Image(fname, 1.0, 1.0)
            self.imgtrans = rendering.Transform()
            self.img.add_attr(self.imgtrans)

        self.viewer.add_onetime(self.img)
        self.pole_transform.set_rotation(self.state[0] + np.pi / 2)
        # Nothing in this class sets last_u, so default to None instead of
        # raising AttributeError; the torque arrow is drawn only when it is set.
        last_u = getattr(self, "last_u", None)
        if last_u:
            self.imgtrans.scale = (-last_u / 2, np.abs(last_u) / 2)

        return self.viewer.render(return_rgb_array=mode == "rgb_array")