Example #1
    def __init__(self, goal_velocity=0, seed=0):
        self.min_action = -1.0
        self.max_action = 1.0
        self.min_position = -1.2
        self.max_position = 0.6
        self.max_speed = 0.07
        self.goal_position = 0.45  # was 0.5 in gym, 0.45 in Arnaud de Broissia's version
        self.goal_velocity = goal_velocity
        self.power = 0.0015
        self.random = Random(seed)

        self.low_state = np.array([self.min_position, -self.max_speed],
                                  dtype=np.float32)
        self.high_state = np.array([self.max_position, self.max_speed],
                                   dtype=np.float32)

        self.action_space = spaces.Box(low=self.min_action,
                                       high=self.max_action,
                                       shape=(1, ),
                                       dtype=np.float32)
        self.observation_space = spaces.Box(low=self.low_state,
                                            high=self.high_state,
                                            dtype=np.float32)

        self.reset()
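
The constructor above is the MountainCar initializer shown in full in Example #10: it stores the physical constants and builds Gym-style Box spaces for the one-dimensional force action and the (position, velocity) observation. A minimal, hedged sketch of exercising those spaces, assuming the class and imports from Example #10 are in scope:

import numpy as np

env = MountainCar(goal_velocity=0, seed=0)

action = env.action_space.sample()           # random force in [-1.0, 1.0], shape (1,)
assert env.action_space.contains(action)

obs = np.array([env.min_position, 0.0], dtype=np.float32)
assert env.observation_space.contains(obs)   # within the low_state / high_state bounds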
Example #2
    def __init__(
        self,
        env: Env,
        learning_rate: Real = 0.001,
        gamma: Real = 0.99,
        max_episode_length: int = 500,
        seed: int = 0,
    ) -> None:
        """
        Description: initializes the Deep agent

        Args:
            env (Env): a deluca environment
            learning_rate (Real): step size for the policy-gradient update
            gamma (Real): discount factor
            max_episode_length (int): maximum number of steps per episode
            seed (int): random seed

        Returns:
            None
        """
        # Create gym and seed numpy
        self.env = env
        self.max_episode_length = max_episode_length
        self.lr = learning_rate
        self.gamma = gamma

        self.random = Random(seed)

        self.reset()
Example #3
    def __init__(self,
                 state_size=1,
                 action_size=1,
                 A=None,
                 B=None,
                 C=None,
                 seed=0):

        if A is not None:
            assert (
                A.shape[0] == state_size and A.shape[1] == state_size
            ), "ERROR: Your input dynamics matrix does not have the correct shape."

        if B is not None:
            assert (
                B.shape[0] == state_size and B.shape[1] == action_size
            ), "ERROR: Your input dynamics matrix does not have the correct shape."

        self.random = Random(seed)

        self.state_size, self.action_size = state_size, action_size

        self.A = (A if A is not None else jax.random.normal(
            self.random.generate_key(), shape=(state_size, state_size)))

        self.B = (B if B is not None else jax.random.normal(
            self.random.generate_key(), shape=(state_size, action_size)))

        self.C = (C if C is not None else jax.numpy.identity(self.state_size))

        self.t = 0

        self.reset()
Example #4
    def __init__(self, reward_fn=None, seed=0, horizon=50):
        # self.reward_fn = reward_fn or default_reward_fn
        self.dt = 0.05
        self.viewer = None

        self.state_size = 2
        self.action_size = 1
        self.action_dim = 1  # redundant with action_size but needed by ILQR

        self.H = horizon

        self.n, self.m = 2, 1
        self.angle_normalize = angle_normalize
        self.nsamples = 0
        self.last_u = None
        self.random = Random(seed)

        self.reset()

        # @jax.jit
        def _dynamics(state, action):
            self.nsamples += 1
            self.last_u = action
            th, thdot = state
            g = 10.0
            m = 1.0
            ell = 1.0
            dt = self.dt

            # Clip the control signal to the torque limits
            action = jnp.clip(action, -self.max_torque, self.max_torque)

            newthdot = (thdot +
                        (-3 * g / (2 * ell) * jnp.sin(th + jnp.pi) + 3.0 /
                         (m * ell**2) * action) * dt)
            newth = th + newthdot * dt
            newthdot = jnp.clip(newthdot, -self.max_speed, self.max_speed)

            return jnp.reshape(jnp.array([newth, newthdot]), (2, ))

        @jax.jit
        def c(x, u):
            # return np.sum(angle_normalize(x[0]) ** 2 + 0.1 * x[1] ** 2 + 0.001 * (u ** 2))
            return angle_normalize(x[0])**2 + .1 * (u[0]**2)

        self.reward_fn = reward_fn or c
        self.dynamics = _dynamics
        self.f, self.f_x, self.f_u = (
            _dynamics,
            jax.jacfwd(_dynamics, argnums=0),
            jax.jacfwd(_dynamics, argnums=1),
        )
        self.c, self.c_x, self.c_u, self.c_xx, self.c_uu = (
            c,
            jax.grad(c, argnums=0),
            jax.grad(c, argnums=1),
            jax.hessian(c, argnums=0),
            jax.hessian(c, argnums=1),
        )
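
This initializer packages everything an iLQR-style planner needs: f is the pendulum dynamics, f_x and f_u its Jacobians via jax.jacfwd, and c, c_x, c_u, c_xx, c_uu the stage cost with its gradients and Hessians. A hedged sketch of evaluating these handles at a state/action pair, assuming this is the Pendulum initializer shown in Example #21 and its imports are available:

import jax.numpy as jnp

env = Pendulum(seed=0, horizon=50)
x = jnp.array([jnp.pi / 4, 0.0])       # (theta, theta_dot)
u = jnp.array([0.1])                   # torque

x_next = env.f(x, u)                   # next state, shape (2,)
A_t = env.f_x(x, u)                    # df/dx, shape (2, 2)
B_t = env.f_u(x, u)                    # df/du, shape (2, 1)

cost = env.c(x, u)                     # scalar stage cost
grad_x, grad_u = env.c_x(x, u), env.c_u(x, u)      # shapes (2,) and (1,)
hess_xx, hess_uu = env.c_xx(x, u), env.c_uu(x, u)  # shapes (2, 2) and (1, 1)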
Example #5
    def __init__(self, seed=0):
        high = np.array([1.0, 1.0, 1.0, 1.0, self.MAX_VEL_1, self.MAX_VEL_2],
                        dtype=np.float32)
        low = -high
        self.random = Random(seed)
        self.observation_space = spaces.Box(low=low,
                                            high=high,
                                            dtype=np.float32)
        self.action_space = spaces.Discrete(3)

        self.reset()
Example #6
    def __init__(self, reward_fn=None, seed=0):
        self.reward_fn = reward_fn or default_reward_fn
        self.dt = 0.05
        self.viewer = None

        self.state_size = 2
        self.action_size = 1

        self.n, self.m = 2, 1
        self.angle_normalize = angle_normalize

        self.random = Random(seed)

        self.reset()
Example #7
class LDS(Env):
    def __init__(self,
                 state_size=1,
                 action_size=1,
                 A=None,
                 B=None,
                 C=None,
                 seed=0):

        if A is not None:
            assert (
                A.shape[0] == state_size and A.shape[1] == state_size
            ), "ERROR: Your input dynamics matrix does not have the correct shape."

        if B is not None:
            assert (
                B.shape[0] == state_size and B.shape[1] == action_size
            ), "ERROR: Your input dynamics matrix does not have the correct shape."

        self.random = Random(seed)

        self.state_size, self.action_size = state_size, action_size

        self.A = (A if A is not None else jax.random.normal(
            self.random.generate_key(), shape=(state_size, state_size)))

        self.B = (B if B is not None else jax.random.normal(
            self.random.generate_key(), shape=(state_size, action_size)))

        self.C = (C if C is not None else jax.numpy.identity(self.state_size))

        self.t = 0

        self.reset()

    def step(self, action):
        self.state = self.A @ self.state + self.B @ action
        self.obs = self.C @ self.state

    @jax.jit
    def dynamics(self, state, action):
        new_state = self.A @ state + self.B @ action
        return new_state

    def reset(self):
        self.state = jax.random.normal(self.random.generate_key(),
                                       shape=(self.state_size, 1))
        self.obs = self.C @ self.state
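
Example #7 is a bare linear dynamical system: step applies x_{t+1} = A x_t + B u_t and exposes the observation y_t = C x_t, while reset draws a fresh Gaussian state. A hedged rollout sketch, assuming the imports used above (jax, jax.numpy as jnp, deluca's Env and Random):

import jax.numpy as jnp

env = LDS(state_size=2, action_size=1, seed=0)
env.reset()                                  # Gaussian initial state of shape (2, 1)

for t in range(5):
    action = jnp.ones((env.action_size, 1))  # constant unit input
    env.step(action)                         # updates env.state and env.obs in place
    print(env.obs)                           # observation y_t = C @ x_t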
Example #8
    def __init__(self,
                 state_size=1,
                 action_size=1,
                 A=None,
                 B=None,
                 C=None,
                 seed=0,
                 noise="normal"):

        if A is not None:
            assert (
                A.shape[0] == state_size and A.shape[1] == state_size
            ), "ERROR: Your input dynamics matrix does not have the correct shape."

        if B is not None:
            assert (
                B.shape[0] == state_size and B.shape[1] == action_size
            ), "ERROR: Your input dynamics matrix does not have the correct shape."

        self.viewer = None
        self.random = Random(seed)
        self.noise = noise

        self.state_size, self.action_size = state_size, action_size
        self.proj_vector1 = jax.random.normal(self.random.generate_key(),
                                              (state_size, ))
        self.proj_vector1 = self.proj_vector1 / jnp.linalg.norm(
            self.proj_vector1)
        self.proj_vector2 = jax.random.normal(self.random.generate_key(),
                                              (state_size, ))
        self.proj_vector2 = self.proj_vector2 / jnp.linalg.norm(
            (self.proj_vector2))
        print('self.proj_vector1:' + str(self.proj_vector1))
        print('self.proj_vector2:' + str(self.proj_vector2))

        self.A = (A if A is not None else jax.random.normal(
            self.random.generate_key(), shape=(state_size, state_size)))

        self.B = (B if B is not None else jax.random.normal(
            self.random.generate_key(), shape=(state_size, action_size)))

        self.C = (C if C is not None else jax.numpy.identity(self.state_size))

        self.t = 0

        self.reset()
Example #9
    def __init__(self, reward_fn=None, seed=0):
        self.viewer = None
        self.gravity = 9.8
        self.masscart = 1.0
        self.masspole = 0.1
        self.total_mass = self.masspole + self.masscart
        self.length = 0.5  # actually half the pole's length
        self.polemass_length = self.masspole * self.length
        self.force_mag = 10.0
        self.tau = 0.02  # seconds between state updates
        self.kinematics_integrator = "euler"

        # Angle at which to fail the episode
        self.theta_threshold_radians = 12 * 2 * math.pi / 360
        self.x_threshold = 2.4

        self.random = Random(seed)

        # Angle limit set to 2 * theta_threshold_radians so failing observation
        # is still within bounds.
        high = np.array(
            [
                self.x_threshold * 2,
                np.finfo(np.float32).max,
                self.theta_threshold_radians * 2,
                np.finfo(np.float32).max,
            ],
            dtype=np.float32,
        )

        # self.action_space = jnp.array([0, 1])
        self.action_space = gym.spaces.Box(low=0, high=1, shape=(1, ))
        # TODO: no longer use gym.spaces
        self.observation_space = gym.spaces.Box(-high, high, dtype=np.float32)

        self.state_size, self.action_size = 4, 1
        self.observation_size = self.state_size

        self.reset()
Example #10
class MountainCar(Env):
    def __init__(self, goal_velocity=0, seed=0):
        self.min_action = -1.0
        self.max_action = 1.0
        self.min_position = -1.2
        self.max_position = 0.6
        self.max_speed = 0.07
        self.goal_position = 0.45  # was 0.5 in gym, 0.45 in Arnaud de Broissia's version
        self.goal_velocity = goal_velocity
        self.power = 0.0015
        self.random = Random(seed)

        self.low_state = np.array([self.min_position, -self.max_speed],
                                  dtype=np.float32)
        self.high_state = np.array([self.max_position, self.max_speed],
                                   dtype=np.float32)

        self.action_space = spaces.Box(low=self.min_action,
                                       high=self.max_action,
                                       shape=(1, ),
                                       dtype=np.float32)
        self.observation_space = spaces.Box(low=self.low_state,
                                            high=self.high_state,
                                            dtype=np.float32)

        self.reset()

    @jax.jit
    def dynamics(self, state, action):
        position = state[0]
        velocity = state[1]

        force = jnp.minimum(jnp.maximum(action, self.min_action),
                            self.max_action)

        velocity += force * self.power - 0.0025 * jnp.cos(3 * position)
        velocity = jnp.clip(velocity, -self.max_speed, self.max_speed)

        position += velocity
        position = jnp.clip(position, self.min_position, self.max_position)

        reset_velocity = (position == self.min_position) & (velocity < 0)
        velocity = jax.lax.cond(reset_velocity == 1, lambda x: 0.0,
                                lambda x: x, velocity)
        # if (position == self.min_position and velocity < 0): velocity = 0

        return jnp.array([position, velocity])

    def step(self, action):

        self.state = self.dynamics(self.state, action)
        position = self.state[0]
        velocity = self.state[1]

        # Done when the car reaches the goal position with at least the goal velocity.
        done = (position >= self.goal_position) & (velocity >=
                                                   self.goal_velocity)

        reward = 100.0 * done
        reward -= jnp.power(action, 2) * 0.1

        return self.state, reward, done, {}

    def reset(self):
        self.state = jnp.array([
            jax.random.uniform(self.random.generate_key(),
                               minval=-0.6,
                               maxval=0.4), 0
        ])
        return self.state

    def _height(self, xs):
        return jnp.sin(3 * xs) * 0.45 + 0.55

    def render(self, mode="human"):
        screen_width = 600
        screen_height = 400

        world_width = self.max_position - self.min_position
        scale = screen_width / world_width
        carwidth = 40
        carheight = 20

        if self.viewer is None:
            from gym.envs.classic_control import rendering

            self.viewer = rendering.Viewer(screen_width, screen_height)
            xs = np.linspace(self.min_position, self.max_position, 100)
            ys = self._height(xs)
            xys = list(zip((xs - self.min_position) * scale, ys * scale))

            self.track = rendering.make_polyline(xys)
            self.track.set_linewidth(4)
            self.viewer.add_geom(self.track)

            clearance = 10

            l, r, t, b = -carwidth / 2, carwidth / 2, carheight, 0
            car = rendering.FilledPolygon([(l, b), (l, t), (r, t), (r, b)])
            car.add_attr(rendering.Transform(translation=(0, clearance)))
            self.cartrans = rendering.Transform()
            car.add_attr(self.cartrans)
            self.viewer.add_geom(car)
            frontwheel = rendering.make_circle(carheight / 2.5)
            frontwheel.set_color(0.5, 0.5, 0.5)
            frontwheel.add_attr(
                rendering.Transform(translation=(carwidth / 4, clearance)))
            frontwheel.add_attr(self.cartrans)
            self.viewer.add_geom(frontwheel)
            backwheel = rendering.make_circle(carheight / 2.5)
            backwheel.add_attr(
                rendering.Transform(translation=(-carwidth / 4, clearance)))
            backwheel.add_attr(self.cartrans)
            backwheel.set_color(0.5, 0.5, 0.5)
            self.viewer.add_geom(backwheel)
            flagx = (self.goal_position - self.min_position) * scale
            flagy1 = self._height(self.goal_position) * scale
            flagy2 = flagy1 + 50
            flagpole = rendering.Line((flagx, flagy1), (flagx, flagy2))
            self.viewer.add_geom(flagpole)
            flag = rendering.FilledPolygon([(flagx, flagy2),
                                            (flagx, flagy2 - 10),
                                            (flagx + 25, flagy2 - 5)])
            flag.set_color(0.8, 0.8, 0)
            self.viewer.add_geom(flag)

        pos = self.state[0]
        self.cartrans.set_translation((pos - self.min_position) * scale,
                                      self._height(pos) * scale)
        self.cartrans.set_rotation(math.cos(3 * pos))

        return self.viewer.render(return_rgb_array=mode == "rgb_array")
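
A hedged rollout sketch for the MountainCar class above. reset() returns the initial (position, velocity) state and step() follows the usual (state, reward, done, info) convention; the sketch assumes jax.numpy as jnp is imported as in the example and that deluca's Env base class supports calling the jit-compiled dynamics method on the instance.

import jax.numpy as jnp

env = MountainCar(goal_velocity=0, seed=0)
state = env.reset()

for t in range(200):
    action = jnp.array([1.0])                 # push right at full power
    state, reward, done, _ = env.step(action)
    if done:                                  # reached goal_position with enough velocity
        break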
Example #11
    def __init__(self, seed=0, horizon=50):
        high = np.array([1.0, 1.0, 1.0, 1.0, self.MAX_VEL_1, self.MAX_VEL_2],
                        dtype=np.float32)
        low = -high
        self.random = Random(seed)
        self.observation_space = spaces.Box(low=low,
                                            high=high,
                                            dtype=np.float32)
        self.action_space = spaces.Discrete(3)
        self.action_dim = 1
        self.H = horizon
        self.nsamples = 0

        # @jax.jit
        def _dynamics(state, action):
            self.nsamples += 1
            # Augment the state with our force action so it can be passed to _dsdt
            augmented_state = jnp.append(state, action)

            new_state = rk4(self._dsdt, augmented_state, [0, self.dt])
            # only care about final timestep of integration returned by integrator
            new_state = new_state[-1]
            new_state = new_state[:4]  # omit action
            # ODEINT IS TOO SLOW!
            # ns_continuous = integrate.odeint(self._dsdt, self.s_continuous, [0, self.dt])
            # self.s_continuous = ns_continuous[-1] # We only care about the state
            # at the ''final timestep'', self.dt

            new_state = new_state.at[0].set(wrap(new_state[0], -pi, pi))
            new_state = new_state.at[1].set(wrap(new_state[1], -pi, pi))
            new_state = new_state.at[2].set(
                bound(new_state[2], -self.MAX_VEL_1, self.MAX_VEL_1))
            new_state = new_state.at[3].set(
                bound(new_state[3], -self.MAX_VEL_2, self.MAX_VEL_2))

            return new_state

        # @jax.jit
        def c(x, u):
            return u[0]**2 + 1 - jnp.exp(0.5 * cos(x[0]) + 0.5 * cos(x[1]) - 1)

        self.dynamics = _dynamics
        self.reward_fn = c

        self.f, self.f_x, self.f_u = (
            _dynamics,
            jax.jacfwd(_dynamics, argnums=0),
            jax.jacfwd(_dynamics, argnums=1),
        )
        self.c, self.c_x, self.c_u, self.c_xx, self.c_uu = (
            c,
            jax.grad(c, argnums=0),
            jax.grad(c, argnums=1),
            jax.hessian(c, argnums=0),
            jax.hessian(c, argnums=1),
        )
        self.reset()
Example #12
class Acrobot(Env):
    """
    Acrobot is a 2-link pendulum with only the second joint actuated.
    Initially, both links point downwards. The goal is to swing the
    end-effector at a height at least the length of one link above the base.
    Both links can swing freely and can pass by each other, i.e., they don't
    collide when they have the same angle.
    **STATE:**
    The state consists of the sin() and cos() of the two rotational joint
    angles and the joint angular velocities :
    [cos(theta1) sin(theta1) cos(theta2) sin(theta2) thetaDot1 thetaDot2].
    For the first link, an angle of 0 corresponds to the link pointing downwards.
    The angle of the second link is relative to the angle of the first link.
    An angle of 0 corresponds to having the same angle between the two links.
    A state of [1, 0, 1, 0, ..., ...] means that both links point downwards.
    **ACTIONS:**
    The action is either applying +1, 0 or -1 torque on the joint between
    the two pendulum links.
    **REFERENCE:**
    .. warning::
        This version of the domain uses the Runge-Kutta method for integrating
        the system dynamics and is more realistic, but also considerably harder
        than the original version which employs Euler integration,
        see the AcrobotLegacy class.
    """

    dt = 0.2

    LINK_LENGTH_1 = 1.0  # [m]
    LINK_LENGTH_2 = 1.0  # [m]
    LINK_MASS_1 = 1.0  #: [kg] mass of link 1
    LINK_MASS_2 = 1.0  #: [kg] mass of link 2
    LINK_COM_POS_1 = 0.5  #: [m] position of the center of mass of link 1
    LINK_COM_POS_2 = 0.5  #: [m] position of the center of mass of link 2
    LINK_MOI = 1.0  #: moments of inertia for both links

    MAX_VEL_1 = 4 * pi
    MAX_VEL_2 = 9 * pi

    AVAIL_TORQUE = jnp.array([-1.0, 0.0, +1])

    torque_noise_max = 0.0

    #: use dynamics equations from the nips paper or the book
    book_or_nips = "book"
    action_arrow = None
    domain_fig = None
    actions_num = 3

    def __init__(self, seed=0, horizon=50):
        high = np.array([1.0, 1.0, 1.0, 1.0, self.MAX_VEL_1, self.MAX_VEL_2],
                        dtype=np.float32)
        low = -high
        self.random = Random(seed)
        self.observation_space = spaces.Box(low=low,
                                            high=high,
                                            dtype=np.float32)
        self.action_space = spaces.Discrete(3)
        self.action_dim = 1
        self.H = horizon
        self.nsamples = 0

        # @jax.jit
        def _dynamics(state, action):
            self.nsamples += 1
            # Augment the state with our force action so it can be passed to _dsdt
            augmented_state = jnp.append(state, action)

            new_state = rk4(self._dsdt, augmented_state, [0, self.dt])
            # only care about final timestep of integration returned by integrator
            new_state = new_state[-1]
            new_state = new_state[:4]  # omit action
            # ODEINT IS TOO SLOW!
            # ns_continuous = integrate.odeint(self._dsdt, self.s_continuous, [0, self.dt])
            # self.s_continuous = ns_continuous[-1] # We only care about the state
            # at the ''final timestep'', self.dt

            new_state = new_state.at[0].set(wrap(new_state[0], -pi, pi))
            new_state = new_state.at[1].set(wrap(new_state[1], -pi, pi))
            new_state = new_state.at[2].set(
                bound(new_state[2], -self.MAX_VEL_1, self.MAX_VEL_1))
            new_state = new_state.at[3].set(
                bound(new_state[3], -self.MAX_VEL_2, self.MAX_VEL_2))

            return new_state

        # @jax.jit
        def c(x, u):
            return u[0]**2 + 1 - jnp.exp(0.5 * cos(x[0]) + 0.5 * cos(x[1]) - 1)

        self.dynamics = _dynamics
        self.reward_fn = c

        self.f, self.f_x, self.f_u = (
            _dynamics,
            jax.jacfwd(_dynamics, argnums=0),
            jax.jacfwd(_dynamics, argnums=1),
        )
        self.c, self.c_x, self.c_u, self.c_xx, self.c_uu = (
            c,
            jax.grad(c, argnums=0),
            jax.grad(c, argnums=1),
            jax.hessian(c, argnums=0),
            jax.hessian(c, argnums=1),
        )
        self.reset()

    def reset(self):
        self.state = jax.random.uniform(
            self.random.generate_key(),
            shape=(4, ),
            minval=-0.1,
            maxval=0.1
            # self.random.generate_key(), shape=(6,), minval=-0.1, maxval=0.1
        )
        return self.state

    def step(self, action):
        # torque = self.AVAIL_TORQUE[action] # discrete action space
        torque = bound(action, -1.0, 1.0)  # continuous action space

        # Add noise to the force action
        if self.torque_noise_max > 0:
            torque += self.np_random.uniform(-self.torque_noise_max,
                                             self.torque_noise_max)
        self.state = self.dynamics(self.state, torque)
        terminal = self._terminal()
        # reward = -1.0 + terminal # openAI cost function
        reward = self.reward_fn(self.state, action)

        # TODO: should this return self.state (dim 4) or self.observation (dim 6)?
        return self.state, reward, terminal, {}

    @property
    def observation(self):
        return jnp.array([
            cos(self.state[0]),
            sin(self.state[0]),
            cos(self.state[1]),
            sin(self.state[1]),
            self.state[2],
            self.state[3],
        ])

    def _terminal(self):
        return -cos(self.state[0]) - cos(self.state[1] + self.state[0]) > 1.0

    def _dsdt(self, augmented_state, t):
        m1 = self.LINK_MASS_1
        m2 = self.LINK_MASS_2
        l1 = self.LINK_LENGTH_1
        lc1 = self.LINK_COM_POS_1
        lc2 = self.LINK_COM_POS_2
        I1 = self.LINK_MOI
        I2 = self.LINK_MOI
        g = 9.8
        a = augmented_state[-1]
        s = augmented_state[:-1]

        theta1 = s[0]
        theta2 = s[1]
        dtheta1 = s[2]
        dtheta2 = s[3]

        d1 = m1 * lc1**2 + m2 * (l1**2 + lc2**2 +
                                 2 * l1 * lc2 * cos(theta2)) + I1 + I2
        d2 = m2 * (lc2**2 + l1 * lc2 * cos(theta2)) + I2
        phi2 = m2 * lc2 * g * cos(theta1 + theta2 - pi / 2.0)
        phi1 = (-m2 * l1 * lc2 * dtheta2**2 * sin(theta2) -
                2 * m2 * l1 * lc2 * dtheta2 * dtheta1 * sin(theta2) +
                (m1 * lc1 + m2 * l1) * g * cos(theta1 - pi / 2) + phi2)
        if self.book_or_nips == "nips":
            # the following line is consistent with the description in the
            # paper
            ddtheta2 = (a + d2 / d1 * phi1 - phi2) / (m2 * lc2**2 + I2 -
                                                      d2**2 / d1)
        else:
            # the following line is consistent with the java implementation and the
            # book
            ddtheta2 = (a + d2 / d1 * phi1 - m2 * l1 * lc2 * dtheta1**2 *
                        sin(theta2) - phi2) / (m2 * lc2**2 + I2 - d2**2 / d1)
        ddtheta1 = -(d2 * ddtheta2 + phi1) / d1
        return (dtheta1, dtheta2, ddtheta1, ddtheta2, 0.0)  # 4-state version
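
A hedged rollout sketch for the Acrobot above: reset() samples the 4-dimensional internal state (two joint angles and their velocities), step() accepts a continuous torque that it bounds to [-1, 1], and the observation property returns the 6-dimensional sin/cos encoding described in the docstring. Assumes jnp and the helpers used above (rk4, wrap, bound, sin, cos) are in scope.

import jax.numpy as jnp

env = Acrobot(seed=0, horizon=50)
state = env.reset()                     # shape (4,): theta1, theta2, dtheta1, dtheta2

for t in range(env.H):
    action = jnp.array([0.5])           # constant torque, bounded inside step()
    state, reward, done, _ = env.step(action)
    obs = env.observation               # shape (6,): sin/cos of angles plus velocities
    if done:                            # end-effector swung above the target height
        break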
Example #13
    def __init__(self, goal_velocity=0, seed=0, horizon=50):
        self.min_action = -1.0
        self.max_action = 1.0
        self.min_position = -1.2
        self.max_position = 0.6
        self.max_speed = 0.07
        self.goal_position = 0.45  # was 0.5 in gym, 0.45 in Arnaud de Broissia's version
        self.goal_velocity = goal_velocity
        self.power = 0.0015
        self.H = horizon
        self.action_dim = 1
        self.random = Random(seed)

        self.low_state = np.array([self.min_position, -self.max_speed],
                                  dtype=np.float32)
        self.high_state = np.array([self.max_position, self.max_speed],
                                   dtype=np.float32)

        self.action_space = spaces.Box(low=self.min_action,
                                       high=self.max_action,
                                       shape=(1, ),
                                       dtype=np.float32)
        self.observation_space = spaces.Box(low=self.low_state,
                                            high=self.high_state,
                                            dtype=np.float32)
        self.nsamples = 0

        # @jax.jit
        def _dynamics(state, action):
            self.nsamples += 1
            position = state[0]
            velocity = state[1]

            force = jnp.minimum(jnp.maximum(action, self.min_action),
                                self.max_action)

            velocity += force * self.power - 0.0025 * jnp.cos(3 * position)
            velocity = jnp.clip(velocity, -self.max_speed, self.max_speed)

            position += velocity
            position = jnp.clip(position, self.min_position, self.max_position)
            reset_velocity = (position == self.min_position) & (velocity < 0)
            # print('state.shape = ' + str(state.shape))
            # print('position.shape = ' + str(position.shape))
            # print('velocity.shape = ' + str(velocity.shape))
            # print('reset_velocity.shape = ' + str(reset_velocity.shape))
            velocity = jax.lax.cond(reset_velocity[0], lambda x: jnp.zeros(
                (1, )), lambda x: x, velocity)
            # print('velocity.shape AFTER = ' + str(velocity.shape))
            return jnp.reshape(jnp.array([position, velocity]), (2, ))

        @jax.jit
        def c(x, u):
            position, velocity = x[0], x[1]
            done = (position >= self.goal_position) & (velocity >=
                                                       self.goal_velocity)
            return -100.0 * done + 0.1 * (u[0] + 1)**2

        self.reward_fn = c
        self.dynamics = _dynamics
        self.f, self.f_x, self.f_u = (
            _dynamics,
            jax.jacfwd(_dynamics, argnums=0),
            jax.jacfwd(_dynamics, argnums=1),
        )
        self.c, self.c_x, self.c_u, self.c_xx, self.c_uu = (
            c,
            jax.grad(c, argnums=0),
            jax.grad(c, argnums=1),
            jax.hessian(c, argnums=0),
            jax.hessian(c, argnums=1),
        )

        self.reset()
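
Like Example #4, this initializer exposes the derivative handles an iLQR-style planner consumes, here for the continuous mountain-car dynamics. A hedged sketch of querying them, assuming this is the initializer of the MountainCar variant shown in Example #16:

import jax.numpy as jnp

env = MountainCar(goal_velocity=0, seed=0, horizon=50)
x = env.reset()                     # (position, velocity)
u = jnp.array([0.5])

x_next = env.f(x, u)                # next state, shape (2,)
A_t = env.f_x(x, u)                 # df/dx, shape (2, 2)
B_t = env.f_u(x, u)                 # df/du, shape (2, 1)

l_u = env.c_u(x, u)                 # d cost/du, shape (1,)
l_uu = env.c_uu(x, u)               # d^2 cost/du^2, shape (1, 1)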
Example #14
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""deluca/tests/utils/test_experiment.py"""
import jax
import numpy as np
import pytest

from deluca.utils import Random
from deluca.utils.experiment import experiment

random = Random()


@pytest.mark.parametrize("shape", [(), (1, ), (1, 2), (1, 2, 3)])
@pytest.mark.parametrize("num_args", [1, 2, 10])
def test_experiment(shape, num_args):
    """Test normal experiment behavior"""
    args = [(
        jax.random.uniform(random.generate_key(), shape=shape),
        jax.random.uniform(random.generate_key(), shape=shape),
    ) for _ in range(num_args)]

    @experiment("a,b", args)
    def dummy(a, b):
        """dummy"""
        return a + b
Example #15
class LDS(Env):
    def __init__(self,
                 state_size=1,
                 action_size=1,
                 A=None,
                 B=None,
                 C=None,
                 seed=0,
                 noise="normal"):

        if A is not None:
            assert (
                A.shape[0] == state_size and A.shape[1] == state_size
            ), "ERROR: Your input dynamics matrix does not have the correct shape."

        if B is not None:
            assert (
                B.shape[0] == state_size and B.shape[1] == action_size
            ), "ERROR: Your input dynamics matrix does not have the correct shape."

        self.viewer = None
        self.random = Random(seed)
        self.noise = noise

        self.state_size, self.action_size = state_size, action_size
        self.proj_vector1 = jax.random.normal(self.random.generate_key(),
                                              (state_size, ))
        self.proj_vector1 = self.proj_vector1 / jnp.linalg.norm(
            self.proj_vector1)
        self.proj_vector2 = jax.random.normal(self.random.generate_key(),
                                              (state_size, ))
        self.proj_vector2 = self.proj_vector2 / jnp.linalg.norm(
            (self.proj_vector2))
        print('self.proj_vector1:' + str(self.proj_vector1))
        print('self.proj_vector2:' + str(self.proj_vector2))

        self.A = (A if A is not None else jax.random.normal(
            self.random.generate_key(), shape=(state_size, state_size)))

        self.B = (B if B is not None else jax.random.normal(
            self.random.generate_key(), shape=(state_size, action_size)))

        self.C = (C if C is not None else jax.numpy.identity(self.state_size))

        self.t = 0

        self.reset()

    def step(self, action):
        self.state = self.A @ self.state + self.B @ action
        if self.noise == "normal":
            self.state += np.random.normal(0,
                                           0.1,
                                           size=(self.state_size,
                                                 self.action_size))
        self.obs = self.C @ self.state

    @jax.jit
    def dynamics(self, state, action):
        new_state = self.A @ state + self.B @ action
        return new_state

    def reset(self):
        self.state = jax.random.normal(self.random.generate_key(),
                                       shape=(self.state_size, 1))
        self.obs = self.C @ self.state

    def render(self, mode="human"):
        if self.viewer is None:
            from gym.envs.classic_control import rendering

            self.viewer = rendering.Viewer(1000, 1000)
            self.viewer.set_bounds(-2.5, 2.5, -2.5, 2.5)
            fname = path.dirname(__file__) + "/classic/assets/lds_arrow.png"
            self.img = rendering.Image(fname, 0.35, 0.35)
            self.img.set_color(1.0, 1.0, 1.0)
            self.imgtrans = rendering.Transform()
            self.img.add_attr(self.imgtrans)
            fnamewind = path.dirname(__file__) + "/classic/assets/lds_grid.png"
            self.imgwind = rendering.Image(fnamewind, 5.0, 5.0)
            self.imgwind.set_color(0.5, 0.5, 0.5)
            self.imgtranswind = rendering.Transform()
            self.imgwind.add_attr(self.imgtranswind)

        self.viewer.add_onetime(self.imgwind)
        self.viewer.add_onetime(self.img)
        cur_x, cur_y = self.imgtrans.translation
        # new_x, new_y = self.state[0], self.state[1]
        new_x, new_y = jnp.dot(self.state.squeeze(),
                               self.proj_vector1), jnp.dot(
                                   self.state.squeeze(), self.proj_vector2)
        diff_x = new_x - cur_x
        diff_y = new_y - cur_y
        new_rotation = jnp.arctan2(diff_y, diff_x)
        self.imgtrans.set_translation(new_x, new_y)
        self.imgtrans.set_rotation(new_rotation)

        return self.viewer.render(return_rgb_array=mode == "rgb_array")

    def close(self):
        if self.viewer:
            self.viewer.close()
            self.viewer = None
Example #16
class MountainCar(Env):
    metadata = {'render.modes': ['human', 'rgb_array']}
    def __init__(self, goal_velocity=0, seed=0, horizon=50):
        self.viewer = None
        self.min_action = -1.0
        self.max_action = 1.0
        self.min_position = -1.2
        self.max_position = 0.6
        self.max_speed = 0.07
        self.goal_position = 0.45  # was 0.5 in gym, 0.45 in Arnaud de Broissia's version
        self.goal_velocity = goal_velocity
        self.power = 0.0015
        self.H = horizon
        self.action_dim = 1
        self.random = Random(seed)

        self.low_state = np.array([self.min_position, -self.max_speed], dtype=np.float32)
        self.high_state = np.array([self.max_position, self.max_speed], dtype=np.float32)

        self.action_space = spaces.Box(
            low=self.min_action, high=self.max_action, shape=(1,), dtype=np.float32
        )
        self.observation_space = spaces.Box(
            low=self.low_state, high=self.high_state, dtype=np.float32
        )
        self.nsamples = 0
        
        # @jax.jit
        def _dynamics(state, action):
            self.nsamples += 1
            position = state[0]
            velocity = state[1]

            force = jnp.minimum(jnp.maximum(action, self.min_action), self.max_action)

            velocity += force * self.power - 0.0025 * jnp.cos(3 * position)
            velocity = jnp.clip(velocity, -self.max_speed, self.max_speed)

            position += velocity
            position = jnp.clip(position, self.min_position, self.max_position)
            reset_velocity = (position == self.min_position) & (velocity < 0)
            # print('state.shape = ' + str(state.shape))
            # print('position.shape = ' + str(position.shape))
            # print('velocity.shape = ' + str(velocity.shape))
            # print('reset_velocity.shape = ' + str(reset_velocity.shape))
            velocity = jax.lax.cond(reset_velocity[0], lambda x: jnp.zeros((1,)), lambda x: x, velocity)
            # print('velocity.shape AFTER = ' + str(velocity.shape))
            return jnp.reshape(jnp.array([position, velocity]), (2,))
        
        @jax.jit
        def c(x, u):
            position, velocity = x[0], x[1]
            done = (position >= self.goal_position) & (velocity >= self.goal_velocity)
            return -100.0 * done + 0.1*(u[0]+1)**2
        self.reward_fn = c
        self.dynamics = _dynamics
        self.f, self.f_x, self.f_u = (
                _dynamics,
                jax.jacfwd(_dynamics, argnums=0),
                jax.jacfwd(_dynamics, argnums=1),
            )
        self.c, self.c_x, self.c_u, self.c_xx, self.c_uu = (
                c,
                jax.grad(c, argnums=0),
                jax.grad(c, argnums=1),
                jax.hessian(c, argnums=0),
                jax.hessian(c, argnums=1),
            )
                           

        self.reset()

    def step(self, action):
        self.state = self.dynamics(self.state, action)
        position = self.state[0]
        velocity = self.state[1]

        # Done when the car reaches the goal position with at least the goal velocity.
        done = (position >= self.goal_position) & (velocity >= self.goal_velocity)

        # reward = 100.0 * done
        # reward -= jnp.power(action, 2) * 0.1
        reward = self.reward_fn(self.state, action)

        return self.state, reward, done, {}

    def reset(self):
        self.state = jnp.array(
            [jax.random.uniform(self.random.generate_key(), minval=-0.6, maxval=0.4), 0]
        )
        return self.state

    def _height(self, xs):
        return jnp.sin(3 * xs) * 0.45 + 0.55

    def render(self, mode="human"):
        screen_width = 600
        screen_height = 400

        world_width = self.max_position - self.min_position
        scale = screen_width / world_width
        carwidth = 40
        carheight = 20

        if self.viewer is None:
            from gym.envs.classic_control import rendering

            self.viewer = rendering.Viewer(screen_width, screen_height)
            xs = np.linspace(self.min_position, self.max_position, 100)
            ys = self._height(xs)
            xys = list(zip((xs - self.min_position) * scale, ys * scale))

            self.track = rendering.make_polyline(xys)
            self.track.set_linewidth(4)
            self.viewer.add_geom(self.track)

            clearance = 10

            l, r, t, b = -carwidth / 2, carwidth / 2, carheight, 0
            car = rendering.FilledPolygon([(l, b), (l, t), (r, t), (r, b)])
            car.add_attr(rendering.Transform(translation=(0, clearance)))
            self.cartrans = rendering.Transform()
            car.add_attr(self.cartrans)
            self.viewer.add_geom(car)
            frontwheel = rendering.make_circle(carheight / 2.5)
            frontwheel.set_color(0.5, 0.5, 0.5)
            frontwheel.add_attr(rendering.Transform(translation=(carwidth / 4, clearance)))
            frontwheel.add_attr(self.cartrans)
            self.viewer.add_geom(frontwheel)
            backwheel = rendering.make_circle(carheight / 2.5)
            backwheel.add_attr(rendering.Transform(translation=(-carwidth / 4, clearance)))
            backwheel.add_attr(self.cartrans)
            backwheel.set_color(0.5, 0.5, 0.5)
            self.viewer.add_geom(backwheel)
            flagx = (self.goal_position - self.min_position) * scale
            flagy1 = self._height(self.goal_position) * scale
            flagy2 = flagy1 + 50
            flagpole = rendering.Line((flagx, flagy1), (flagx, flagy2))
            self.viewer.add_geom(flagpole)
            flag = rendering.FilledPolygon(
                [(flagx, flagy2), (flagx, flagy2 - 10), (flagx + 25, flagy2 - 5)]
            )
            flag.set_color(0.8, 0.8, 0)
            self.viewer.add_geom(flag)

        pos = self.state[0]
        self.cartrans.set_translation((pos - self.min_position) * scale, self._height(pos) * scale)
        self.cartrans.set_rotation(math.cos(3 * pos))

        return self.viewer.render(return_rgb_array=mode == "rgb_array")
Example #17
class Deep(Agent):
    """
    Generic deep controller that uses zero-order methods to train on an
    environment.
    """

    def __init__(
        self,
        env: Env,
        learning_rate: Real = 0.001,
        gamma: Real = 0.99,
        max_episode_length: int = 500,
        seed: int = 0,
    ) -> None:
        """
        Description: initializes the Deep agent

        Args:
            env (Env): a deluca environment
            learning_rate (Real): step size for the policy-gradient update
            gamma (Real): discount factor
            max_episode_length (int): maximum number of steps per episode
            seed (int): random seed

        Returns:
            None
        """
        # Create gym and seed numpy
        self.env = env
        self.max_episode_length = max_episode_length
        self.lr = learning_rate
        self.gamma = gamma

        self.random = Random(seed)

        self.reset()

    def reset(self) -> None:
        """
        Description: reset agent

        Args:
            None

        Returns:
            None
        """
        # Init weight
        self.W = jax.random.uniform(
            self.random.generate_key(),
            shape=(self.env.state_size, len(self.env.action_space)),
            minval=0,
            maxval=1,
        )

        # Keep stats for final print of graph
        self.episode_rewards = []

        self.current_episode_length = 0
        self.current_episode_reward = 0
        self.episode_rewards = jnp.zeros(self.max_episode_length)
        self.episode_grads = jnp.zeros((self.max_episode_length, self.W.shape[0], self.W.shape[1]))

    def policy(self, state: jnp.ndarray, w: jnp.ndarray) -> jnp.ndarray:
        """
        Description: Policy that maps state to action parameterized by w

        Args:
            state (jnp.ndarray):
            w (jnp.ndarray):
        """
        z = jnp.dot(state, w)
        exp = jnp.exp(z)
        return exp / jnp.sum(exp)

    def softmax_grad(self, softmax: jnp.ndarray) -> jnp.ndarray:
        """
        Description: Vectorized softmax Jacobian

        Args:
            softmax (jnp.ndarray)
        """
        s = softmax.reshape(-1, 1)
        return jnp.diagflat(s) - jnp.dot(s, s.T)

    def __call__(self, state: jnp.ndarray):
        """
        Description: provide an action given a state

        Args:
            state (jnp.ndarray):

        Returns:
            jnp.ndarray: action to take
        """
        self.state = state
        self.probs = self.policy(state, self.W)
        self.action = jax.random.choice(
            self.random.generate_key(), a=self.env.action_space, p=self.probs
        )
        return self.action

    def feed(self, reward: Real) -> None:
        """
        Description: compute gradient and save with reward in memory for weight updates

        Args:
            reward (Real):

        Returns:
            None
        """
        dsoftmax = self.softmax_grad(self.probs)[self.action, :]
        dlog = dsoftmax / self.probs[self.action]
        grad = self.state.reshape(-1, 1) @ dlog.reshape(1, -1)

        self.episode_rewards = self.episode_rewards.at[self.current_episode_length].set(reward)
        self.episode_grads = self.episode_grads.at[self.current_episode_length].set(grad)
        self.current_episode_length += 1

    def update(self) -> None:
        """
        Description: update weights
        
        Args:
            None

        Returns:
            None
        """
        for i in range(self.current_episode_length):
            # Loop through everything that happened in the episode and update
            # towards the log policy gradient times **FUTURE** discounted reward
            self.W += self.lr * self.episode_grads[i] * jnp.sum(
                jnp.array(
                    [
                        r * (self.gamma ** t)
                        for t, r in enumerate(
                            self.episode_rewards[i : self.current_episode_length]
                        )
                    ]
                )
            )

        # reset episode length
        self.current_episode_length = 0
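
A hedged training-loop sketch for the Deep agent above. It assumes an already-constructed deluca environment env whose action_space is an indexable array of discrete actions (as the agent's reset() and __call__ imply) and whose step() returns the usual (state, reward, done, info) tuple; the episode count is arbitrary.

agent = Deep(env, learning_rate=1e-3, gamma=0.99, max_episode_length=500, seed=0)

for episode in range(100):
    state = env.reset()
    for _ in range(agent.max_episode_length):
        action = agent(state)              # sample an action from the softmax policy
        state, reward, done, _ = env.step(action)
        agent.feed(reward)                 # record reward and log-policy gradient
        if done:
            break
    agent.update()                         # REINFORCE-style update of agent.W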
Example #18
class CartPole(Env):
    """
    Description:
        A pole is attached by an un-actuated joint to a cart, which moves along
        a frictionless track. The pendulum starts upright, and the goal is to
        prevent it from falling over by increasing and reducing the cart's
        velocity.

    Source:
        This environment corresponds to the version of the cart-pole problem
        described by Barto, Sutton, and Anderson

    Observation:
        Type: Box(4)
        Num	Observation               Min             Max
        0	Cart Position             -4.8            4.8
        1	Cart Velocity             -Inf            Inf
        2	Pole Angle                -24 deg         24 deg
        3	Pole Velocity At Tip      -Inf            Inf

    Actions:
        Type: Discrete(2)
        Num	Action
        0	Push cart to the left
        1	Push cart to the right

        Note: The amount by which the velocity is reduced or increased is not
        fixed; it depends on the angle the pole is pointing. This is because
        the center of gravity of the pole increases the amount of energy needed
        to move the cart underneath it

    Reward:
        Reward is 1 for every step taken, including the termination step

    Starting State:
        All observations are assigned a uniform random value in [-0.05..0.05]

    Episode Termination:
        Pole Angle is more than 12 degrees.
        Cart Position is more than 2.4 (center of the cart reaches the edge of
        the display).
        Episode length is greater than 200. (not really - only in make, not in
                                            the actual CartPoleEnv class)
        Solved Requirements:
        Considered solved when the average reward is greater than or equal to
        195.0 over 100 consecutive trials.
    """
    metadata = {'render.modes': ['human', 'rgb_array']}

    def __init__(self, reward_fn=None, seed=0):
        self.viewer = None
        self.gravity = 9.8
        self.masscart = 1.0
        self.masspole = 0.1
        self.total_mass = self.masspole + self.masscart
        self.length = 0.5  # actually half the pole's length
        self.polemass_length = self.masspole * self.length
        self.force_mag = 10.0
        self.tau = 0.02  # seconds between state updates
        self.kinematics_integrator = "euler"

        # Angle at which to fail the episode
        self.theta_threshold_radians = 12 * 2 * math.pi / 360
        self.x_threshold = 2.4

        self.random = Random(seed)

        # Angle limit set to 2 * theta_threshold_radians so failing observation
        # is still within bounds.
        high = np.array(
            [
                self.x_threshold * 2,
                np.finfo(np.float32).max,
                self.theta_threshold_radians * 2,
                np.finfo(np.float32).max,
            ],
            dtype=np.float32,
        )

        # self.action_space = jnp.array([0, 1])
        self.action_space = gym.spaces.Box(low=0, high=1, shape=(1, ))
        # TODO: no longer use gym.spaces
        self.observation_space = gym.spaces.Box(-high, high, dtype=np.float32)

        self.state_size, self.action_size = 4, 1
        self.observation_size = self.state_size

        self.reset()

    # @jax.jit
    def dynamics(self, state, action):

        x, x_dot, theta, theta_dot = state

        force = jax.lax.cond(action == 1, lambda x: x, lambda x: -x,
                             self.force_mag)

        costheta = jnp.cos(theta)
        sintheta = jnp.sin(theta)

        temp = (force + self.polemass_length * theta_dot**2 *
                sintheta) / self.total_mass
        thetaacc = (self.gravity * sintheta - costheta * temp) / (
            self.length *
            (4.0 / 3.0 - self.masspole * costheta**2 / self.total_mass))
        xacc = temp - self.polemass_length * thetaacc * costheta / self.total_mass

        if self.kinematics_integrator == "euler":
            x = x + self.tau * x_dot
            x_dot = x_dot + self.tau * xacc
            theta = theta + self.tau * theta_dot
            theta_dot = theta_dot + self.tau * thetaacc
        else:  # semi-implicit euler
            x_dot = x_dot + self.tau * xacc
            x = x + self.tau * x_dot
            theta_dot = theta_dot + self.tau * thetaacc
            theta = theta + self.tau * theta_dot

        return jnp.array([x, x_dot, theta, theta_dot])

    def reset(self):
        self.state = jax.random.uniform(self.random.generate_key(),
                                        shape=(4, ),
                                        minval=-0.05,
                                        maxval=0.05)
        return self.state

    def step(self, action):
        print('self.state:' + str(self.state))
        print('action:' + str(action))
        print('type(self.state):' + str(type(self.state)))
        print('type(action):' + str(type(action)))
        self.state = self.dynamics(self.state, action)
        x, x_dot, theta, theta_dot = self.state

        done = jax.lax.cond(
            (jnp.abs(x) > jnp.abs(self.x_threshold))
            | (jnp.abs(theta) > jnp.abs(self.theta_threshold_radians)),
            lambda _: True,
            lambda _: False,
            None,
        )

        reward = 1 - done

        return self.state, reward, done, {}

    def render(self, mode="human"):
        screen_width = 600
        screen_height = 400

        world_width = self.x_threshold * 2
        scale = screen_width / world_width
        carty = 100  # TOP OF CART
        polewidth = 10.0
        polelen = scale * (2 * self.length)
        cartwidth = 50.0
        cartheight = 30.0

        if self.viewer is None:
            from gym.envs.classic_control import rendering

            self.viewer = rendering.Viewer(screen_width, screen_height)
            l, r, t, b = -cartwidth / 2, cartwidth / 2, cartheight / 2, -cartheight / 2
            axleoffset = cartheight / 4.0
            cart = rendering.FilledPolygon([(l, b), (l, t), (r, t), (r, b)])
            self.carttrans = rendering.Transform()
            cart.add_attr(self.carttrans)
            self.viewer.add_geom(cart)
            l, r, t, b = -polewidth / 2, polewidth / 2, polelen - polewidth / 2, -polewidth / 2
            pole = rendering.FilledPolygon([(l, b), (l, t), (r, t), (r, b)])
            pole.set_color(0.8, 0.6, 0.4)
            self.poletrans = rendering.Transform(translation=(0, axleoffset))
            pole.add_attr(self.poletrans)
            pole.add_attr(self.carttrans)
            self.viewer.add_geom(pole)
            self.axle = rendering.make_circle(polewidth / 2)
            self.axle.add_attr(self.poletrans)
            self.axle.add_attr(self.carttrans)
            self.axle.set_color(0.5, 0.5, 0.8)
            self.viewer.add_geom(self.axle)
            self.track = rendering.Line((0, carty), (screen_width, carty))
            self.track.set_color(0, 0, 0)
            self.viewer.add_geom(self.track)

            self._pole_geom = pole

        if self.state is None:
            return None

        # Edit the pole polygon vertex
        pole = self._pole_geom
        l, r, t, b = -polewidth / 2, polewidth / 2, polelen - polewidth / 2, -polewidth / 2
        pole.v = [(l, b), (l, t), (r, t), (r, b)]

        x = self.state
        cartx = x[0] * scale + screen_width / 2.0  # MIDDLE OF CART
        self.carttrans.set_translation(cartx, carty)
        self.poletrans.set_rotation(-x[2])

        return self.viewer.render(return_rgb_array=mode == "rgb_array")
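
A hedged rollout sketch for the CartPole above. Although action_space is declared as a Box, the dynamics branch on action == 1, so the sketch passes the discrete values 0 and 1; it assumes the imports used in the example.

env = CartPole(seed=0)
state = env.reset()                        # shape (4,): x, x_dot, theta, theta_dot

for t in range(200):
    action = 1 if state[2] > 0 else 0      # crude heuristic: push toward the pole's lean
    state, reward, done, _ = env.step(action)
    if done:                               # cart position or pole angle out of bounds
        break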
Example #19
    def __init__(
        self,
        A: jnp.ndarray,
        B: jnp.ndarray,
        C: jnp.ndarray = None,
        K: jnp.ndarray = None,
        cost_fn: Callable[[jnp.ndarray, jnp.ndarray], Real] = None,
        m: int = 10,
        h: int = 50,
        lr_scale: Real = 0.03,
        decay: bool = True,
        RM: int = 1000,
        seed: int = 0,
    ) -> None:
        """
        Description: Initialize the dynamics of the model.

        Args:
            A (jnp.ndarray): system dynamics
            B (jnp.ndarray): system dynamics
            C (jnp.ndarray): observation matrix
            K (jnp.ndarray): starting policy (optional); defaults to the zero matrix
            cost_fn (Callable[[jnp.ndarray, jnp.ndarray], Real]): cost function
            m (positive int): history length of the controller
            h (positive int): history length of the system
            lr_scale (Real): learning-rate scale
            decay (bool): whether to decay the learning rate over time
            RM (int): bound on the norm of the controller parameters M
            seed (int): random seed
        """

        cost_fn = cost_fn or quad_loss

        self.random = Random(seed)

        d_state, d_action = B.shape  # State & Action Dimensions

        C = jnp.identity(d_state) if C is None else C

        d_obs = C.shape[0]  # Observation Dimension

        self.t = 0  # Time Counter (for decaying learning rate)

        self.m, self.h = m, h

        self.lr_scale, self.decay = lr_scale, decay

        self.RM = RM

        # Construct truncated markov operator G
        self.G = jnp.zeros((h, d_obs, d_action))
        A_power = jnp.identity(d_state)
        for i in range(h):
            self.G = self.G.at[i].set(C @ A_power @ B)
            A_power = A_power @ A

        # Model Parameters
        # initial linear policy / perturbation contributions / bias
        self.K = K if K is not None else jnp.zeros((d_action, d_obs))

        self.M = lr_scale * jax.random.normal(
            self.random.generate_key(), shape=(m, d_action, d_obs)
        )

        # Past m nature y's such that y_nat[0] is the most recent
        self.y_nat = jnp.zeros((m, d_obs, 1))

        # Past h u's such that u[0] is the most recent
        self.us = jnp.zeros((h, d_action, 1))

        def policy_loss(M, G, y_nat, us):
            """Surrogate cost function"""

            def action(obs):
                """Action function"""
                return -self.K @ obs + jnp.tensordot(M, y_nat, axes=([0, 2], [0, 1]))

            final_state = y_nat[0] + jnp.tensordot(G, us, axes=([0, 2], [0, 1]))
            return cost_fn(final_state, action(final_state))

        self.policy_loss = policy_loss
        self.grad = jit(grad(policy_loss))
Example #20
class DRC(Agent):
    def __init__(
        self,
        A: jnp.ndarray,
        B: jnp.ndarray,
        C: jnp.ndarray = None,
        K: jnp.ndarray = None,
        cost_fn: Callable[[jnp.ndarray, jnp.ndarray], Real] = None,
        m: int = 10,
        h: int = 50,
        lr_scale: Real = 0.03,
        decay: bool = True,
        RM: int = 1000,
        seed: int = 0,
    ) -> None:
        """
        Description: Initialize the dynamics of the model.

        Args:
            A (jnp.ndarray): system dynamics
            B (jnp.ndarray): system dynamics
            C (jnp.ndarray): observation matrix
            K (jnp.ndarray): starting policy (optional); defaults to the zero matrix
            cost_fn (Callable[[jnp.ndarray, jnp.ndarray], Real]): cost function
            m (positive int): history length of the controller
            h (positive int): history length of the system
            lr_scale (Real): learning-rate scale
            decay (bool): whether to decay the learning rate over time
            RM (int): bound on the norm of the controller parameters M
            seed (int): random seed
        """

        cost_fn = cost_fn or quad_loss

        self.random = Random(seed)

        d_state, d_action = B.shape  # State & Action Dimensions

        C = jnp.identity(d_state) if C is None else C

        d_obs = C.shape[0]  # Observation Dimension

        self.t = 0  # Time Counter (for decaying learning rate)

        self.m, self.h = m, h

        self.lr_scale, self.decay = lr_scale, decay

        self.RM = RM

        # Construct truncated markov operator G
        self.G = jnp.zeros((h, d_obs, d_action))
        A_power = jnp.identity(d_state)
        for i in range(h):
            self.G = self.G.at[i].set(C @ A_power @ B)
            A_power = A_power @ A

        # Model Parameters
        # initial linear policy / perturbation contributions / bias
        self.K = K if K is not None else jnp.zeros((d_action, d_obs))

        self.M = lr_scale * jax.random.normal(
            self.random.generate_key(), shape=(m, d_action, d_obs)
        )

        # Past m nature y's such that y_nat[0] is the most recent
        self.y_nat = jnp.zeros((m, d_obs, 1))

        # Past h u's such that u[0] is the most recent
        self.us = jnp.zeros((h, d_action, 1))

        def policy_loss(M, G, y_nat, us):
            """Surrogate cost function"""

            def action(obs):
                """Action function"""
                return -self.K @ obs + jnp.tensordot(M, y_nat, axes=([0, 2], [0, 1]))

            final_state = y_nat[0] + jnp.tensordot(G, us, axes=([0, 2], [0, 1]))
            return cost_fn(final_state, action(final_state))

        self.policy_loss = policy_loss
        self.grad = jit(grad(policy_loss))

    def __call__(self, obs: jnp.ndarray) -> jnp.ndarray:
        """
        Description: Return the action based on current state and internal parameters.

        Args:
            state (jnp.ndarray): current state

        Returns:
           jnp.ndarray: action to take
        """
        # update y_nat
        self.update_noise(obs)

        # get action
        action = self.get_action(obs)

        # update Parameters
        self.update_params(obs, action)

        return action

    def get_action(self, obs: jnp.ndarray) -> jnp.ndarray:
        """
        Description: get action from state.

        Args:
            state (jnp.ndarray):

        Returns:
            jnp.ndarray
        """

        return -self.K @ obs + jnp.tensordot(self.M, self.y_nat, axes=([0, 2], [0, 1]))

    def update(self, obs: jnp.ndarray, u: jnp.ndarray) -> None:
        """Description: Update the noise estimate and the parameters after observing obs and playing u."""
        self.update_noise(obs)
        self.update_params(obs, u)

    def update_noise(self, obs: jnp.ndarray) -> None:
        """Description: Recover nature's y (the observation with the effect of past actions removed) and store it."""
        y_nat = obs - jnp.tensordot(self.G, self.us, axes=([0, 2], [0, 1]))
        self.y_nat = jnp.roll(self.y_nat, 1, axis=0)
        self.y_nat = self.y_nat.at[0].set(y_nat)

    def update_params(self, obs: jnp.ndarray, u: jnp.ndarray) -> None:
        """
        Description: update agent internal state.

        Args:
            state (jnp.ndarray):

        Returns:
            None
        """

        # gradient of the surrogate cost with respect to M
        delta_M = self.grad(self.M, self.G, self.y_nat, self.us)

        # learning rate: decay as lr_scale / (t + 1) when enabled, otherwise keep lr_scale
        lr = jax.lax.cond(self.decay, lambda x: x / (self.t + 1), lambda x: x, self.lr_scale)

        self.M -= lr * delta_M

        # project M back onto the ball of radius RM whenever its norm exceeds RM
        self.M = jax.lax.cond(
            jnp.linalg.norm(self.M) > self.RM,
            lambda x: x * (self.RM / jnp.linalg.norm(x)),
            lambda x: x,
            self.M,
        )

        # update us
        self.us = jnp.roll(self.us, 1, axis=0)
        self.us = self.us.at[0].set(u)

        self.t += 1
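
A minimal closed-loop sketch for the DRC agent above, run on a toy linear system. It assumes the DRC class, together with the library's quad_loss and Random helpers, is in scope; the system matrices, noise scale, and hyperparameters are illustrative only.

import jax
import jax.numpy as jnp

d_state, d_action = 2, 1

# A stable toy system (illustrative, not from the original example).
A = jnp.array([[0.9, 0.1], [0.0, 0.9]])
B = jnp.array([[0.0], [1.0]])

agent = DRC(A, B, m=5, h=20, lr_scale=0.01)

key = jax.random.PRNGKey(0)
x = jnp.zeros((d_state, 1))
for t in range(100):
    obs = x                      # C defaults to the identity, so the observation is the state
    u = agent(obs)               # updates y_nat, returns an action, and takes a gradient step
    key, sub = jax.random.split(key)
    w = 0.1 * jax.random.normal(sub, shape=(d_state, 1))  # process noise
    x = A @ x + B @ u + w
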
Example #21
class Pendulum(Env):
    max_speed = 8.0
    max_torque = 2.0  # same torque limit as gym's Pendulum
    high = np.array([1.0, 1.0, max_speed])

    action_space = spaces.Box(low=-max_torque,
                              high=max_torque,
                              shape=(1, ),
                              dtype=np.float32)
    observation_space = spaces.Box(low=-high, high=high, dtype=np.float32)
    metadata = {'render.modes': ['human', 'rgb_array']}

    def __init__(self, reward_fn=None, seed=0, horizon=50):
        # self.reward_fn = reward_fn or default_reward_fn
        self.dt = 0.05
        self.viewer = None

        self.state_size = 2
        self.action_size = 1
        self.action_dim = 1  # redundant with action_size but needed by ILQR

        self.H = horizon

        self.n, self.m = 2, 1
        self.angle_normalize = angle_normalize
        self.nsamples = 0
        self.last_u = None
        self.random = Random(seed)

        self.reset()

        # intentionally not jitted: _dynamics mutates self.nsamples and self.last_u,
        # and jit would only run those Python side effects once, during tracing
        def _dynamics(state, action):
            self.nsamples += 1
            self.last_u = action
            th, thdot = state
            g = 10.0
            m = 1.0
            ell = 1.0
            dt = self.dt

            # Clip the control signal to the torque limits
            action = jnp.clip(action, -self.max_torque, self.max_torque)

            newthdot = (thdot +
                        (-3 * g / (2 * ell) * jnp.sin(th + jnp.pi) + 3.0 /
                         (m * ell**2) * action) * dt)
            newth = th + newthdot * dt
            newthdot = jnp.clip(newthdot, -self.max_speed, self.max_speed)

            return jnp.reshape(jnp.array([newth, newthdot]), (2, ))

        @jax.jit
        def c(x, u):
            # alternative cost: np.sum(angle_normalize(x[0]) ** 2 + 0.1 * x[1] ** 2 + 0.001 * (u ** 2))
            return angle_normalize(x[0]) ** 2 + 0.1 * (u[0] ** 2)

        self.reward_fn = reward_fn or c
        self.dynamics = _dynamics
        self.f, self.f_x, self.f_u = (
            _dynamics,
            jax.jacfwd(_dynamics, argnums=0),
            jax.jacfwd(_dynamics, argnums=1),
        )
        self.c, self.c_x, self.c_u, self.c_xx, self.c_uu = (
            c,
            jax.grad(c, argnums=0),
            jax.grad(c, argnums=1),
            jax.hessian(c, argnums=0),
            jax.hessian(c, argnums=1),
        )

    def reset(self):
        th = jax.random.uniform(self.random.generate_key(),
                                minval=-jnp.pi,
                                maxval=jnp.pi)
        thdot = jax.random.uniform(self.random.generate_key(),
                                   minval=-1.0,
                                   maxval=1.0)

        self.state = jnp.array([th, thdot])

        return self.state

    def render(self, mode="human"):
        if self.viewer is None:
            from gym.envs.classic_control import rendering

            self.viewer = rendering.Viewer(500, 500)
            self.viewer.set_bounds(-2.2, 2.2, -2.2, 2.2)
            rod = rendering.make_capsule(1, 0.2)
            rod.set_color(0.8, 0.3, 0.3)
            self.pole_transform = rendering.Transform()
            rod.add_attr(self.pole_transform)
            self.viewer.add_geom(rod)
            axle = rendering.make_circle(0.05)
            axle.set_color(0, 0, 0)
            self.viewer.add_geom(axle)
            fname = path.join(path.dirname(__file__), "assets/clockwise.png")
            self.img = rendering.Image(fname, 1.0, 1.0)
            self.imgtrans = rendering.Transform()
            self.img.add_attr(self.imgtrans)

        self.viewer.add_onetime(self.img)
        self.pole_transform.set_rotation(self.state[0] + np.pi / 2)
        if self.last_u is not None:
            self.imgtrans.scale = (-self.last_u / 2, np.abs(self.last_u) / 2)

        return self.viewer.render(return_rgb_array=mode == "rgb_array")
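
An open-loop rollout sketch for the Pendulum above (zero torque at every step). It assumes the library's angle_normalize and Random helpers are in scope, since the class relies on both.

import jax.numpy as jnp

env = Pendulum(seed=0, horizon=50)
state = env.reset()

total_cost = 0.0
for t in range(env.H):
    u = jnp.array([0.0])            # no control; a planner (e.g. ILQR) would supply this
    total_cost += env.c(state, u)   # stage cost: angle^2 + 0.1 * u^2
    state = env.dynamics(state, u)  # one Euler step of the pendulum dynamics

print(total_cost)
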
Example #22
class Pendulum(Env):
    max_speed = 8.0
    max_torque = 2.0  # gym 2.
    high = np.array([1.0, 1.0, max_speed])

    action_space = gym.spaces.Box(low=-max_torque, high=max_torque, shape=(1,), dtype=np.float32)
    observation_space = gym.spaces.Box(low=-high, high=high, dtype=np.float32)

    def __init__(self, reward_fn=None, seed=0):
        self.reward_fn = reward_fn or default_reward_fn
        self.dt = 0.05
        self.viewer = None

        self.state_size = 2
        self.action_size = 1

        self.n, self.m = 2, 1
        self.angle_normalize = angle_normalize
        self.last_u = None  # read by render(); note that this variant's dynamics() never updates it

        self.random = Random(seed)

        self.reset()

    @partial(jax.jit, static_argnums=(0,))  # assumes `from functools import partial`; `self` must be marked static to jit a bound method
    def dynamics(self, state, action):
        th, thdot = state
        g = 10.0
        m = 1.0
        ell = 1.0
        dt = self.dt

        # Clip the control signal to the torque limits
        action = jnp.clip(action, -self.max_torque, self.max_torque)

        newthdot = (
            thdot + (-3 * g / (2 * ell) * jnp.sin(th + jnp.pi) + 3.0 / (m * ell ** 2) * action) * dt
        )
        newth = th + newthdot * dt
        newthdot = jnp.clip(newthdot, -self.max_speed, self.max_speed)

        return jnp.reshape(jnp.array([newth, newthdot]), (2,))  # flatten to shape (2,) to match reset()

    def reset(self):
        th = jax.random.uniform(self.random.generate_key(), minval=-jnp.pi, maxval=jnp.pi)
        thdot = jax.random.uniform(self.random.generate_key(), minval=-1.0, maxval=1.0)

        self.state = jnp.array([th, thdot])

        return self.state

    def render(self, mode="human"):
        if self.viewer is None:
            from gym.envs.classic_control import rendering

            self.viewer = rendering.Viewer(500, 500)
            self.viewer.set_bounds(-2.2, 2.2, -2.2, 2.2)
            rod = rendering.make_capsule(1, 0.2)
            rod.set_color(0.8, 0.3, 0.3)
            self.pole_transform = rendering.Transform()
            rod.add_attr(self.pole_transform)
            self.viewer.add_geom(rod)
            axle = rendering.make_circle(0.05)
            axle.set_color(0, 0, 0)
            self.viewer.add_geom(axle)
            fname = path.join(path.dirname(__file__), "assets/clockwise.png")
            self.img = rendering.Image(fname, 1.0, 1.0)
            self.imgtrans = rendering.Transform()
            self.img.add_attr(self.imgtrans)

        self.viewer.add_onetime(self.img)
        self.pole_transform.set_rotation(self.state[0] + np.pi / 2)
        if self.last_u is not None:
            self.imgtrans.scale = (-self.last_u / 2, np.abs(self.last_u) / 2)

        return self.viewer.render(return_rgb_array=mode == "rgb_array")
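
A quick smoke-test sketch for this variant. The dummy reward_fn is a hypothetical stand-in that avoids relying on default_reward_fn; angle_normalize and Random are assumed to be in scope as in the rest of the library.

import jax.numpy as jnp

env = Pendulum(reward_fn=lambda x, u: 0.0, seed=0)  # dummy reward (hypothetical stand-in)

state = env.reset()
for _ in range(10):
    state = env.dynamics(state, jnp.array([0.5]))  # constant torque; dynamics is jitted with `self` static
print(state)
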