def __init__(
    self,
    env: Env,
    learning_rate: Real = 0.001,
    gamma: Real = 0.99,
    max_episode_length: int = 500,
    seed: int = 0,
) -> None:
    """
    Description: initializes the Deep agent

    Args:
        env (Env): a deluca environment
        learning_rate (Real): step size for gradient updates
        gamma (Real): discount factor for future rewards
        max_episode_length (int): maximum number of steps per episode
        seed (int): seed for the pseudo-random number generator

    Returns:
        None
    """
    # Store the environment and seed the random number generator
    self.env = env
    self.max_episode_length = max_episode_length
    self.lr = learning_rate
    self.gamma = gamma
    self.random = Random(seed)
    self.reset()
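# Usage sketch: instantiate the agent with a deluca environment. The
# `deluca.agents.Deep` and `deluca.envs.Pendulum` import paths are
# assumptions about the package layout; the keyword arguments match the
# signature above.
from deluca.agents import Deep
from deluca.envs import Pendulum

agent = Deep(env=Pendulum(), learning_rate=1e-3, gamma=0.99,
             max_episode_length=500, seed=0)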
def __init__(self, state_size=1, action_size=1, A=None, B=None, C=None, seed=0):
    if A is not None:
        assert (
            A.shape[0] == state_size and A.shape[1] == state_size
        ), "ERROR: A must have shape (state_size, state_size)."
    if B is not None:
        assert (
            B.shape[0] == state_size and B.shape[1] == action_size
        ), "ERROR: B must have shape (state_size, action_size)."

    self.random = Random(seed)
    self.state_size, self.action_size = state_size, action_size

    # Sample any unspecified system matrices at random; C defaults to the
    # identity (fully observed system).
    self.A = (A if A is not None else jax.random.normal(
        self.random.generate_key(), shape=(state_size, state_size)))
    self.B = (B if B is not None else jax.random.normal(
        self.random.generate_key(), shape=(state_size, action_size)))
    self.C = C if C is not None else jax.numpy.identity(self.state_size)

    self.t = 0
    self.reset()
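# Usage sketch: a 2-state, 1-input linear system. The `deluca.envs.LDS`
# import path is an assumption about the package layout; the constructor
# arguments match the signature above.
import jax.numpy as jnp
from deluca.envs import LDS

A = jnp.array([[1.0, 0.1],
               [0.0, 1.0]])
B = jnp.array([[0.0],
               [0.1]])
env = LDS(state_size=2, action_size=1, A=A, B=B, seed=0)  # C defaults to identity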
def __init__(self, goal_velocity=0, seed=0):
    self.min_action = -1.0
    self.max_action = 1.0
    self.min_position = -1.2
    self.max_position = 0.6
    self.max_speed = 0.07
    self.goal_position = 0.45  # was 0.5 in gym, 0.45 in Arnaud de Broissia's version
    self.goal_velocity = goal_velocity
    self.power = 0.0015
    self.random = Random(seed)

    self.low_state = np.array([self.min_position, -self.max_speed], dtype=np.float32)
    self.high_state = np.array([self.max_position, self.max_speed], dtype=np.float32)

    self.action_space = spaces.Box(
        low=self.min_action, high=self.max_action, shape=(1,), dtype=np.float32
    )
    self.observation_space = spaces.Box(
        low=self.low_state, high=self.high_state, dtype=np.float32
    )

    self.reset()
def __init__(self, reward_fn=None, seed=0, horizon=50):
    self.dt = 0.05
    self.viewer = None
    self.state_size = 2
    self.action_size = 1
    self.action_dim = 1  # redundant with action_size but needed by ILQR
    self.H = horizon
    self.n, self.m = 2, 1
    self.angle_normalize = angle_normalize
    self.nsamples = 0
    self.last_u = None
    self.random = Random(seed)
    self.reset()

    # Not jitted: _dynamics mutates Python-side state (nsamples, last_u).
    def _dynamics(state, action):
        self.nsamples += 1
        self.last_u = action
        th, thdot = state
        g = 10.0
        m = 1.0
        ell = 1.0
        dt = self.dt

        # Clip the control signal to the allowed torque range
        action = jnp.clip(action, -self.max_torque, self.max_torque)

        newthdot = thdot + (-3 * g / (2 * ell) * jnp.sin(th + jnp.pi)
                            + 3.0 / (m * ell**2) * action) * dt
        newth = th + newthdot * dt
        newthdot = jnp.clip(newthdot, -self.max_speed, self.max_speed)
        return jnp.reshape(jnp.array([newth, newthdot]), (2,))

    @jax.jit
    def c(x, u):
        # Full gym cost: angle_normalize(x[0])**2 + 0.1*x[1]**2 + 0.001*(u**2)
        return angle_normalize(x[0]) ** 2 + 0.1 * (u[0] ** 2)

    self.reward_fn = reward_fn or c
    self.dynamics = _dynamics

    # Dynamics and cost handles with the derivatives required by ILQR
    self.f, self.f_x, self.f_u = (
        _dynamics,
        jax.jacfwd(_dynamics, argnums=0),
        jax.jacfwd(_dynamics, argnums=1),
    )
    self.c, self.c_x, self.c_u, self.c_xx, self.c_uu = (
        c,
        jax.grad(c, argnums=0),
        jax.grad(c, argnums=1),
        jax.hessian(c, argnums=0),
        jax.hessian(c, argnums=1),
    )
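# Quick sanity check of the ILQR handles defined above (a sketch; the
# `Pendulum` class name/import is an assumption, and `max_torque` /
# `max_speed` must be provided as class attributes for _dynamics to run).
import jax.numpy as jnp
from deluca.envs import Pendulum

env = Pendulum()
x0 = jnp.array([jnp.pi / 4, 0.0])  # pole tilted, at rest
u0 = jnp.zeros((1,))
A_lin = env.f_x(x0, u0)  # (2, 2) Jacobian of the dynamics w.r.t. the state
B_lin = env.f_u(x0, u0)  # (2, 1) Jacobian of the dynamics w.r.t. the action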
def __init__(self, seed=0):
    high = np.array(
        [1.0, 1.0, 1.0, 1.0, self.MAX_VEL_1, self.MAX_VEL_2], dtype=np.float32
    )
    low = -high
    self.random = Random(seed)
    self.observation_space = spaces.Box(low=low, high=high, dtype=np.float32)
    self.action_space = spaces.Discrete(3)
    self.reset()
def __init__(self, reward_fn=None, seed=0):
    self.reward_fn = reward_fn or default_reward_fn
    self.dt = 0.05
    self.viewer = None
    self.state_size = 2
    self.action_size = 1
    self.n, self.m = 2, 1
    self.angle_normalize = angle_normalize
    self.random = Random(seed)
    self.reset()
def __init__(self, state_size=1, action_size=1, A=None, B=None, C=None,
             seed=0, noise="normal"):
    if A is not None:
        assert (
            A.shape[0] == state_size and A.shape[1] == state_size
        ), "ERROR: A must have shape (state_size, state_size)."
    if B is not None:
        assert (
            B.shape[0] == state_size and B.shape[1] == action_size
        ), "ERROR: B must have shape (state_size, action_size)."

    self.viewer = None
    self.random = Random(seed)
    self.noise = noise
    self.state_size, self.action_size = state_size, action_size

    # Two random unit vectors (used to project the state, e.g. for rendering)
    self.proj_vector1 = jax.random.normal(self.random.generate_key(), (state_size,))
    self.proj_vector1 = self.proj_vector1 / jnp.linalg.norm(self.proj_vector1)
    self.proj_vector2 = jax.random.normal(self.random.generate_key(), (state_size,))
    self.proj_vector2 = self.proj_vector2 / jnp.linalg.norm(self.proj_vector2)

    # Sample any unspecified system matrices at random; C defaults to the
    # identity (fully observed system).
    self.A = (A if A is not None else jax.random.normal(
        self.random.generate_key(), shape=(state_size, state_size)))
    self.B = (B if B is not None else jax.random.normal(
        self.random.generate_key(), shape=(state_size, action_size)))
    self.C = C if C is not None else jax.numpy.identity(self.state_size)

    self.t = 0
    self.reset()
def __init__(self, reward_fn=None, seed=0):
    self.viewer = None

    self.gravity = 9.8
    self.masscart = 1.0
    self.masspole = 0.1
    self.total_mass = self.masspole + self.masscart
    self.length = 0.5  # actually half the pole's length
    self.polemass_length = self.masspole * self.length
    self.force_mag = 10.0
    self.tau = 0.02  # seconds between state updates
    self.kinematics_integrator = "euler"

    # Angle at which to fail the episode
    self.theta_threshold_radians = 12 * 2 * math.pi / 360
    self.x_threshold = 2.4

    self.random = Random(seed)

    # Angle limit set to 2 * theta_threshold_radians so failing observation
    # is still within bounds.
    high = np.array(
        [
            self.x_threshold * 2,
            np.finfo(np.float32).max,
            self.theta_threshold_radians * 2,
            np.finfo(np.float32).max,
        ],
        dtype=np.float32,
    )

    self.action_space = gym.spaces.Box(low=0, high=1, shape=(1,))  # TODO: no longer use gym.spaces
    self.observation_space = gym.spaces.Box(-high, high, dtype=np.float32)

    self.state_size, self.action_size = 4, 1
    self.observation_size = self.state_size

    self.reset()
def __init__(self, seed=0, horizon=50):
    high = np.array(
        [1.0, 1.0, 1.0, 1.0, self.MAX_VEL_1, self.MAX_VEL_2], dtype=np.float32
    )
    low = -high
    self.random = Random(seed)
    self.observation_space = spaces.Box(low=low, high=high, dtype=np.float32)
    self.action_space = spaces.Discrete(3)
    self.action_dim = 1
    self.H = horizon
    self.nsamples = 0

    # Not jitted: _dynamics increments a Python-side sample counter.
    def _dynamics(state, action):
        self.nsamples += 1

        # Augment the state with our force action so it can be passed to _dsdt
        augmented_state = jnp.append(state, action)

        # RK4 integration; we only care about the state at the final
        # timestep, self.dt (odeint proved too slow here).
        new_state = rk4(self._dsdt, augmented_state, [0, self.dt])
        new_state = new_state[-1]
        new_state = new_state[:4]  # omit action

        # Wrap angles and bound velocities (jnp arrays are immutable, so
        # use functional .at[].set() updates).
        new_state = new_state.at[0].set(wrap(new_state[0], -pi, pi))
        new_state = new_state.at[1].set(wrap(new_state[1], -pi, pi))
        new_state = new_state.at[2].set(
            bound(new_state[2], -self.MAX_VEL_1, self.MAX_VEL_1))
        new_state = new_state.at[3].set(
            bound(new_state[3], -self.MAX_VEL_2, self.MAX_VEL_2))
        return new_state

    def c(x, u):
        # Control-effort penalty plus a smooth height-based term
        return u[0] ** 2 + 1 - jnp.exp(0.5 * cos(x[0]) + 0.5 * cos(x[1]) - 1)

    self.dynamics = _dynamics
    self.reward_fn = c

    # Dynamics and cost handles with the derivatives required by ILQR
    self.f, self.f_x, self.f_u = (
        _dynamics,
        jax.jacfwd(_dynamics, argnums=0),
        jax.jacfwd(_dynamics, argnums=1),
    )
    self.c, self.c_x, self.c_u, self.c_xx, self.c_uu = (
        c,
        jax.grad(c, argnums=0),
        jax.grad(c, argnums=1),
        jax.hessian(c, argnums=0),
        jax.hessian(c, argnums=1),
    )

    self.reset()
def __init__(self, goal_velocity=0, seed=0, horizon=50):
    self.min_action = -1.0
    self.max_action = 1.0
    self.min_position = -1.2
    self.max_position = 0.6
    self.max_speed = 0.07
    self.goal_position = 0.45  # was 0.5 in gym, 0.45 in Arnaud de Broissia's version
    self.goal_velocity = goal_velocity
    self.power = 0.0015
    self.H = horizon
    self.action_dim = 1
    self.random = Random(seed)

    self.low_state = np.array([self.min_position, -self.max_speed], dtype=np.float32)
    self.high_state = np.array([self.max_position, self.max_speed], dtype=np.float32)

    self.action_space = spaces.Box(
        low=self.min_action, high=self.max_action, shape=(1,), dtype=np.float32
    )
    self.observation_space = spaces.Box(
        low=self.low_state, high=self.high_state, dtype=np.float32
    )

    self.nsamples = 0

    # Not jitted: _dynamics increments a Python-side sample counter.
    def _dynamics(state, action):
        self.nsamples += 1
        position = state[0]
        velocity = state[1]
        force = jnp.minimum(jnp.maximum(action, self.min_action), self.max_action)

        velocity += force * self.power - 0.0025 * jnp.cos(3 * position)
        velocity = jnp.clip(velocity, -self.max_speed, self.max_speed)
        position += velocity
        position = jnp.clip(position, self.min_position, self.max_position)

        # If the car hits the left wall while moving left, stop it.
        # jax.lax.cond keeps the branch traceable under jit.
        reset_velocity = (position == self.min_position) & (velocity < 0)
        velocity = jax.lax.cond(
            reset_velocity[0], lambda x: jnp.zeros((1,)), lambda x: x, velocity
        )
        return jnp.reshape(jnp.array([position, velocity]), (2,))

    @jax.jit
    def c(x, u):
        # The cost must depend on its arguments (not self.state) so that
        # the ILQR derivatives below are meaningful.
        position, velocity = x[0], x[1]
        done = (position >= self.goal_position) & (velocity >= self.goal_velocity)
        return -100.0 * done + 0.1 * (u[0] + 1) ** 2

    self.reward_fn = c
    self.dynamics = _dynamics
    self.f, self.f_x, self.f_u = (
        _dynamics,
        jax.jacfwd(_dynamics, argnums=0),
        jax.jacfwd(_dynamics, argnums=1),
    )
    self.c, self.c_x, self.c_u, self.c_xx, self.c_uu = (
        c,
        jax.grad(c, argnums=0),
        jax.grad(c, argnums=1),
        jax.hessian(c, argnums=0),
        jax.hessian(c, argnums=1),
    )

    self.reset()
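# Illustrative helper (an assumption, not part of deluca): gather the local
# linear-quadratic model that an ILQR backward pass would consume from the
# derivative handles defined above. `env` is an instance of this class.
def local_lq_model(env, x, u):
    A, B = env.f_x(x, u), env.f_u(x, u)    # dynamics Jacobians
    q, r = env.c_x(x, u), env.c_u(x, u)    # cost gradients
    Q, R = env.c_xx(x, u), env.c_uu(x, u)  # cost Hessians
    return A, B, q, r, Q, R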
def __init__(
    self,
    A: jnp.ndarray,
    B: jnp.ndarray,
    C: jnp.ndarray = None,
    K: jnp.ndarray = None,
    cost_fn: Callable[[jnp.ndarray, jnp.ndarray], Real] = None,
    m: int = 10,
    h: int = 50,
    lr_scale: Real = 0.03,
    decay: bool = True,
    RM: int = 1000,
    seed: int = 0,
) -> None:
    """
    Description: Initialize the controller and the dynamics of the model.

    Args:
        A (jnp.ndarray): system dynamics
        B (jnp.ndarray): system dynamics
        C (jnp.ndarray): observation matrix; defaults to the identity
        K (jnp.ndarray): starting policy (optional); defaults to the zero matrix
        cost_fn (Callable[[jnp.ndarray, jnp.ndarray], Real]): defaults to quad_loss
        m (positive int): history of the controller
        h (positive int): history of the system
        lr_scale (Real): scale of the learning rate
        decay (bool): whether to decay the learning rate over time
        RM (int):
        seed (int): random seed
    """
    cost_fn = cost_fn or quad_loss

    self.random = Random(seed)

    d_state, d_action = B.shape  # State & Action Dimensions

    C = jnp.identity(d_state) if C is None else C

    d_obs = C.shape[0]  # Observation Dimension

    self.t = 0  # Time Counter (for decaying learning rate)

    self.m, self.h = m, h

    self.lr_scale, self.decay = lr_scale, decay

    self.RM = RM

    # Construct truncated Markov operator G, where G[i] = C A^i B
    self.G = jnp.zeros((h, d_obs, d_action))
    A_power = jnp.identity(d_state)
    for i in range(h):
        self.G = self.G.at[i].set(C @ A_power @ B)
        A_power = A_power @ A

    # Model Parameters
    # initial linear policy / perturbation contributions / bias
    self.K = K if K is not None else jnp.zeros((d_action, d_obs))

    self.M = lr_scale * jax.random.normal(
        self.random.generate_key(), shape=(m, d_action, d_obs)
    )

    # Past m nature y's such that y_nat[0] is the most recent
    self.y_nat = jnp.zeros((m, d_obs, 1))

    # Past h u's such that u[0] is the most recent
    self.us = jnp.zeros((h, d_action, 1))

    def policy_loss(M, G, y_nat, us):
        """Surrogate cost function"""

        def action(obs):
            """Action function"""
            return -self.K @ obs + jnp.tensordot(M, y_nat, axes=([0, 2], [0, 1]))

        final_state = y_nat[0] + jnp.tensordot(G, us, axes=([0, 2], [0, 1]))
        return cost_fn(final_state, action(final_state))

    self.policy_loss = policy_loss
    self.grad = jit(grad(policy_loss))
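# A minimal sketch (not deluca's actual update method) of one gradient step
# on the perturbation parameters M, using only attributes defined in the
# constructor above. `agent` stands for an instance of this controller.
def gradient_step(agent, lr):
    """Apply a single gradient update to agent.M via the jitted grad above."""
    delta_M = agent.grad(agent.M, agent.G, agent.y_nat, agent.us)
    agent.M = agent.M - lr * delta_M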
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""deluca/tests/utils/test_experiment.py"""
import jax
import numpy as np
import pytest

from deluca.utils import Random
from deluca.utils.experiment import experiment

random = Random()


@pytest.mark.parametrize("shape", [(), (1,), (1, 2), (1, 2, 3)])
@pytest.mark.parametrize("num_args", [1, 2, 10])
def test_experiment(shape, num_args):
    """Test normal experiment behavior"""
    args = [
        (
            jax.random.uniform(random.generate_key(), shape=shape),
            jax.random.uniform(random.generate_key(), shape=shape),
        )
        for _ in range(num_args)
    ]

    @experiment("a,b", args)
    def dummy(a, b):
        """dummy"""
        return a + b