def __init__(self, gravity=9.8, masscart=1.0, masspole=0.1, length=.5, gamma=0.99, tau=.02, init_state_params=None, name="Cart-Pendulum"):
    '''
    Summary:
        Build the initial CartPoleState, initialise the MDP base class with
        this class's actions, transition function and reward function, and
        record the physical parameters and termination thresholds.

    Args:
        gravity (float): stored as self.gravity.
        masscart (float): stored as self.masscart; part of self.total_mass.
        masspole (float): stored as self.masspole; part of self.total_mass
            and self.polemass_length.
        length (float): stored as self.length; part of self.polemass_length.
        gamma (float): forwarded to MDP.__init__ and stored as self.gamma.
        tau (float): stored as self.tau (presumably the simulation time
            step -- it is not used in this method; confirm in the helper).
        init_state_params (dict): optional; when given it must contain the
            keys "x", "x_dot", "theta" and "theta_dot". When None, every
            component of the initial state is 0.
        name (str): stored as self.name.
    '''
    state_keys = ("x", "x_dot", "theta", "theta_dot")
    if init_state_params is not None:
        # Keys are read in the order listed above; a missing key raises KeyError.
        start_kwargs = {key: init_state_params[key] for key in state_keys}
    else:
        start_kwargs = {key: 0 for key in state_keys}
    start_state = CartPoleState(**start_kwargs)

    MDP.__init__(
        self,
        CartPoleMDP.ACTIONS,
        self._transition_func,
        self._reward_func,
        init_state=start_state,
        gamma=gamma,
    )

    # Values taken directly from the constructor arguments.
    self.gravity = gravity
    self.masscart = masscart
    self.masspole = masspole
    self.length = length
    self.gamma = gamma
    self.tau = tau
    self.name = name

    # Termination thresholds.
    # Absolute value of the limit on the cart's x position.
    self.x_threshold = 2.4
    # Angle away from vertical beyond which a state is considered terminal
    # (20 degrees, converted to radians).
    self.theta_threshold = self._degrees_to_radians(20)

    # Quantities derived from the parameters above.
    self.total_mass = self.masscart + self.masspole
    self.polemass_length = self.masspole * self.length
def _transition_func(self, state, action):
    '''
    Summary:
        Compute the successor of @state under @action and flag it as
        terminal when it falls outside the allowed thresholds.

    Args:
        state (State)
        action (str)

    Returns:
        (State): a new CartPoleState built from the values returned by
            self._transition_helper.
    '''
    new_x, new_x_dot, new_theta, new_theta_dot = self._transition_helper(state, action)
    successor = CartPoleState(
        x=new_x,
        x_dot=new_x_dot,
        theta=new_theta,
        theta_dot=new_theta_dot,
    )

    # A successor outside the threshold values is a terminal state.
    inside_limits = self._is_within_threshold(theta=successor.theta, x=successor.x)
    if not inside_limits:
        successor.set_terminal(True)

    return successor
def reset(self, init_state_params=None):
    '''
    Summary:
        Return the MDP to its initial state, optionally replacing that
        initial state first.

    Args:
        init_state_params (dict): optional; when given it must contain the
            keys "x", "x_dot", "theta" and "theta_dot", which are used to
            build a new initial CartPoleState (a missing key raises
            KeyError). When None, the existing self.init_state is kept.
    '''
    if init_state_params is not None:
        self.init_state = CartPoleState(
            x=init_state_params["x"],
            x_dot=init_state_params["x_dot"],
            theta=init_state_params["theta"],
            theta_dot=init_state_params["theta_dot"],
        )

    # The current state gets its own copy so that in-place changes to it
    # (e.g. marking it terminal) can never leak into the stored initial
    # state and corrupt later resets.
    self.cur_state = copy.deepcopy(self.init_state)