def test_insert_tensor_col(orig, col):
    """
    Insert `col` at every admissible column index of `orig` and verify the outcome.

    For each index, the result must keep the number of rows, gain exactly one column, and hold the values of `col`
    at the insertion index.

    :param orig: 2-dim tensor to insert into (presumably supplied by a pytest fixture or parametrization)
    :param col: column to insert (presumably supplied by a pytest fixture or parametrization)
    """
    num_rows = orig.shape[0]
    num_cols = orig.shape[1]

    # The index equal to the number of columns covers the appending case
    for insert_at in range(num_cols + 1):
        augmented = insert_tensor_col(orig, insert_at, col)

        # The number of rows is unchanged, the number of columns grew by one
        assert num_rows == augmented.shape[0]
        assert num_cols == augmented.shape[1] - 1

        # The new column holds the inserted values
        to.testing.assert_allclose(augmented[:, insert_at], col.squeeze())
def step(self, snapshot_mode: str, meta_info: dict = None):
    """
    Determine the LQR gains for the wrapped environment, load them into the policy, evaluate the resulting
    controller on sampled rollouts, log the return statistics, and save a snapshot.

    For a `BallOnPlate5DSim` the gains are a fixed, precomputed K-matrix. For a `QBallBalancerSim` the gains are
    obtained by solving the continuous-time Riccati equation for a linear model of the system.

    :param snapshot_mode: passed on to `make_snapshot()`
    :param meta_info: passed on to `make_snapshot()`
    """
    sim = inner_env(self._env)

    if isinstance(sim, BallOnPlate5DSim):
        # Precomputed K-matrix: one row per action (5), one column per state (14)
        ctrl_gains = to.tensor([
            [0.1401, 0, 0, 0, -0.09819, -0.1359, 0, 0.545, 0, 0, 0, -0.01417, -0.04427, 0],
            [0, 0.1381, 0, 0.2518, 0, 0, -0.2142, 0, 0.5371, 0, 0.03336, 0, 0, -0.1262],
            [0, 0, 0.1414, 0.0002534, 0, 0, -0.0002152, 0, 0, 0.5318, 0, 0, 0, -0.0001269],
            [0, -0.479, -0.0004812, 39.24, 0, 0, -15.44, 0, -1.988, -0.001934, 9.466, 0, 0, -13.14],
            [0.3039, 0, 0, 0, 25.13, 15.66, 0, 1.284, 0, 0, 0, 7.609, 6.296, 0]
        ])

        # Compensate for the mismatching state definition by adding zero-gain columns
        if self.ball_z_dim_mismatch:
            ctrl_gains = insert_tensor_col(ctrl_gains, 7, to.zeros((5, 1)))  # ball z position
            ctrl_gains = insert_tensor_col(ctrl_gains, -1, to.zeros((5, 1)))  # ball z velocity

    elif isinstance(sim, QBallBalancerSim):
        # Since the control module can be tricky to install (recommended using anaconda), we only load it if needed
        import control

        # System modeling. The constants are read into locals, so the object returned by `domain_param` is not
        # modified as a side effect of computing the gains.
        dp = self._env.domain_param
        m_ball, g, r_ball = dp['m_ball'], dp['g'], dp['r_ball']
        J_eq = self._env._J_eq
        B_eq_v = self._env._B_eq_v
        c_kin = self._env._c_kin
        zeta = self._env._zeta
        A_m = self._env._A_m

        obs_dim = self._env.obs_space.flat_dim
        act_dim = self._env.act_space.flat_dim
        half_dim = obs_dim//2

        # NOTE(review): the hard-coded indices below presuppose an 8-dim observation and a 2-dim action -- confirm
        A = np.zeros((obs_dim, obs_dim))
        A[:half_dim, half_dim:] = np.eye(half_dim)  # the second half of the state drives the first half
        A[4, 4] = -B_eq_v/J_eq
        A[5, 5] = -B_eq_v/J_eq
        A[6, 0] = c_kin*m_ball*g*r_ball**2/zeta
        A[6, 6] = -c_kin*r_ball**2/zeta
        A[7, 1] = c_kin*m_ball*g*r_ball**2/zeta
        A[7, 7] = -c_kin*r_ball**2/zeta

        B = np.zeros((obs_dim, act_dim))
        B[4, 0] = A_m/J_eq
        B[5, 1] = A_m/J_eq
        # The output matrices C and D are not passed to `control.lqr()`, hence they are not constructed here

        # Get the weighting matrices
        rew_fcn = self._env.task.rew_fcn
        if isinstance(rew_fcn, QuadrErrRewFcn):
            # The environment uses a reward function compatible with the LQR
            Q = rew_fcn.Q
            R = rew_fcn.R
        else:
            # The environment does not use a reward function compatible with the LQR, apply some fine tuning
            Q = np.diag([1e2, 1e2, 5e2, 5e2, 1e-2, 1e-2, 5e+0, 5e+0])
            R = np.diag([1e-2, 1e-2])

        # Solve the continuous time Riccati eq
        K, _, self.eigvals = control.lqr(A, B, Q, R)  # for discrete system pass dt
        ctrl_gains = to.from_numpy(K).to(to.get_default_dtype())

    else:
        raise pyrado.TypeErr(given=sim, expected_type=[BallOnPlate5DSim, QBallBalancerSim])

    # Assign the controller gains
    self._policy.init_param(-1*ctrl_gains)  # in classical control it is u = -K*x; here a = psi(s)*s

    # Sample rollouts to evaluate the LQR
    ros = self.sampler.sample()

    # Logging
    rets = [ro.undiscounted_return() for ro in ros]
    self.logger.add_value('max return', np.max(rets), 4)
    self.logger.add_value('median return', np.median(rets), 4)
    self.logger.add_value('min return', np.min(rets), 4)
    self.logger.add_value('avg return', np.mean(rets), 4)
    self.logger.add_value('std return', np.std(rets), 4)
    self.logger.add_value('avg rollout len', np.mean([ro.length for ro in ros]), 4)
    self.logger.add_value('num total samples', self._cnt_samples)
    policy_params = self._policy.param_values
    self.logger.add_value('min mag policy param', policy_params[to.argmin(abs(policy_params))])
    self.logger.add_value('max mag policy param', policy_params[to.argmax(abs(policy_params))])

    # Save snapshot data
    self.make_snapshot(snapshot_mode, float(np.mean(rets)), meta_info)
def get_lin_ctrl(env: SimEnv, ctrl_type: str, ball_z_dim_mismatch: bool = True) -> LinearPolicy:
    """
    Construct a linear controller specified by its controller gains.
    Parameters for BallOnPlate5DSim by Markus Lamprecht (clipped gains < 1e-5 to 0).

    :param env: environment, after unwrapping it must be a `BallOnPlate5DSim` or a `QBallBalancerSim`
    :param ctrl_type: type of the controller (case-insensitive): 'lqr' or 'h2' for `BallOnPlate5DSim`,
                      'pd' or 'lqr' for `QBallBalancerSim`
    :param ball_z_dim_mismatch: only used for `BallOnPlate5DSim`; set to `True` if the given controller does not
                                have the z component (relative position) of the ball in the state vector,
                                i.e. state is 14-dim instead of 16-dim
    :return: controller compatible with Pyrado Policy
    """
    # Imported locally, as in the original; presumably to avoid requiring RcsPySim at module import time
    from pyrado.environments.rcspysim.ball_on_plate import BallOnPlate5DSim

    sim = inner_env(env)

    if isinstance(sim, BallOnPlate5DSim):
        # Get the controller gains (K-matrix)
        if ctrl_type.lower() == 'lqr':
            ctrl_gains = to.tensor([
                [0.1401, 0, 0, 0, -0.09819, -0.1359, 0, 0.545, 0, 0, 0, -0.01417, -0.04427, 0],
                [0, 0.1381, 0, 0.2518, 0, 0, -0.2142, 0, 0.5371, 0, 0.03336, 0, 0, -0.1262],
                [0, 0, 0.1414, 0.0002534, 0, 0, -0.0002152, 0, 0, 0.5318, 0, 0, 0, -0.0001269],
                [0, -0.479, -0.0004812, 39.24, 0, 0, -15.44, 0, -1.988, -0.001934, 9.466, 0, 0, -13.14],
                [0.3039, 0, 0, 0, 25.13, 15.66, 0, 1.284, 0, 0, 0, 7.609, 6.296, 0]
            ])

        elif ctrl_type.lower() == 'h2':
            ctrl_gains = to.tensor([
                [-73.88, -2.318, 39.49, -4.270, 12.25, 0.9779, 0.2564, 35.11, 5.756, 0.8661, -0.9898, 1.421, 3.132,
                 -0.01899],
                [-24.45, 0.7202, -10.58, 2.445, -0.6957, 2.1619, -0.3966, -61.66, -3.254, 5.356, 0.1908, 12.88,
                 6.142, -0.3812],
                [-101.8, -9.011, 64.345, -5.091, 17.83, -2.636, 0.9506, -44.28, 3.206, 37.59, 2.965, -32.65, -21.68,
                 -0.1133],
                [-59.56, 1.56, -0.5794, 26.54, -2.503, 3.827, -7.534, 9.999, 1.143, -16.96, 8.450, -5.302, 4.620,
                 -10.32],
                [-107.1, 0.4359, 19.03, -9.601, 20.33, 10.36, 0.2285, -74.98, -2.136, 7.084, -1.240, 62.62, 33.66,
                 1.790]
            ])

        else:
            raise pyrado.ValueErr(given=ctrl_type, eq_constraint="'lqr' or 'h2'")

        # Compensate for the mismatching state definition by adding zero-gain columns
        if ball_z_dim_mismatch:
            ctrl_gains = insert_tensor_col(ctrl_gains, 7, to.zeros((5, 1)))  # ball z position
            ctrl_gains = insert_tensor_col(ctrl_gains, -1, to.zeros((5, 1)))  # ball z velocity

    elif isinstance(sim, QBallBalancerSim):
        # Get the controller gains (K-matrix)
        if ctrl_type.lower() == 'pd':
            # Quanser gains (the original Quanser controller includes action clipping)
            ctrl_gains = -to.tensor([[-14., 0, -14*3.45, 0, 0, 0, -14*2.11, 0],
                                     [0, -14., 0, -14*3.45, 0, 0, 0, -14*2.11]])

        elif ctrl_type.lower() == 'lqr':
            # Since the control module can be tricky to install (recommended using anaconda), we only load it
            # if needed
            import control

            # System modeling
            # NOTE(review): the hard-coded indices below presuppose an 8-dim observation and a 2-dim action -- confirm
            obs_dim = env.obs_space.flat_dim
            act_dim = env.act_space.flat_dim
            half_dim = obs_dim//2

            A = np.zeros((obs_dim, obs_dim))
            A[:half_dim, half_dim:] = np.eye(half_dim)  # the second half of the state drives the first half
            A[4, 4] = -env.B_eq_v/env.J_eq
            A[5, 5] = -env.B_eq_v/env.J_eq
            A[6, 0] = env.c_kin*env.m_ball*env.g*env.r_ball**2/env.zeta
            A[6, 6] = -env.c_kin*env.r_ball**2/env.zeta
            A[7, 1] = env.c_kin*env.m_ball*env.g*env.r_ball**2/env.zeta
            A[7, 7] = -env.c_kin*env.r_ball**2/env.zeta

            B = np.zeros((obs_dim, act_dim))
            B[4, 0] = env.A_m/env.J_eq
            B[5, 1] = env.A_m/env.J_eq
            # The output matrices C and D are not passed to `control.lqr()`, hence they are not constructed here

            # Get the weighting matrices from the environment
            Q = env.task.rew_fcn.Q
            R = env.task.rew_fcn.R

            # Solve the continuous time Riccati eq
            K, _, _ = control.lqr(A, B, Q, R)  # for discrete system pass dt
            ctrl_gains = to.from_numpy(K).to(to.get_default_dtype())

        else:
            raise pyrado.ValueErr(given=ctrl_type, eq_constraint="'pd', 'lqr'")

    else:
        # Report every supported environment type, not only the first one
        raise pyrado.TypeErr(given=sim, expected_type=[BallOnPlate5DSim, QBallBalancerSim])

    # Reconstruct the controller
    feats = FeatureStack([identity_feat])
    ctrl = LinearPolicy(env.spec, feats)
    ctrl.init_param(-1*ctrl_gains)  # in classical control it is u = -K*x; here a = psi(s)*s
    return ctrl