import numpy as np
import tensorflow.compat.v1 as tf
from envs.env import ArmEnv
from contextual_policy_search.trajectory import TRAJECTROY as tra

# Switch TensorFlow to its v1-compatible behaviour for the rest of the run.
tf.disable_v2_behavior()
print("import finished")

# Build the arm environment and put it into its initial state.
env = ArmEnv()
env.reset()

# Accumulator for episode reward; kept at module level for later code.
ep_reward = 0

# Proportional / derivative gains, six entries each.
# NOTE(review): presumably one entry per controlled degree of freedom of
# ArmEnv -- confirm against the environment's action layout.
kp = np.array([0.007, 0.007, 0.03, 1e-4, 1e-4, 0.005])
kd = np.array([0.01, 0.01, 0.05, 1e-4, 1e-4, 0.005])
K = [kp, kd]  # gain pair passed to the trajectory generator as [kp, kd]

# Roll out a PD-controlled trajectory in the environment; the returned
# value is kept as `memory` for use further down the script.
trajectory = tra()
memory = trajectory.pd_trajectory(env, K)
def _str_to_bool(value):
    """Convert a command-line flag value to a real ``bool``.

    Without a ``type`` argparse stores the raw string, so
    ``--evaluate_Q_value False`` would yield the truthy string ``"False"``.
    Raises ValueError (reported by argparse as an invalid value) for
    unrecognised text.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("1", "true", "t", "yes", "y"):
        return True
    if text in ("0", "false", "f", "no", "n"):
        return False
    raise ValueError("expected a boolean, got %r" % (value,))


parser.add_argument("--evaluate_Q_value", default=False, type=_str_to_bool)  # Parsed as a real bool ("False" -> False)
parser.add_argument("--reward_name", default='r_s')
parser.add_argument("--seq_len", default=2, type=int)
parser.add_argument("--ini_seed", default=1, type=int)  # NOTE(review): not read in this block; presumably consumed by main -- confirm
parser.add_argument("--seed", default=10, type=int)  # Sets Gym, PyTorch and Numpy seeds (overridden per run in the loop below)
# argparse only applies `type` to *string* defaults, so float literals such
# as 1e3 would stay floats; use int defaults so the value is an int whether
# or not the flag is given on the command line.
parser.add_argument("--start_timesteps", default=int(1e3), type=int)  # How many time steps purely random policy is run for
parser.add_argument("--eval_freq", default=int(1e3), type=int)  # How often (time steps) we evaluate
parser.add_argument("--max_timesteps", default=int(1e5), type=int)  # Max time steps to run environment for
parser.add_argument("--expl_noise", default=0.1, type=float)  # Std of Gaussian exploration noise
parser.add_argument("--state_noise", default=0, type=float)  # Std of Gaussian noise on states (presumably -- confirm in main)
parser.add_argument("--batch_size", default=100, type=int)  # Batch size for both actor and critic
parser.add_argument("--discount", default=0.99, type=float)  # Discount factor
parser.add_argument("--tau", default=0.005, type=float)  # Target network update rate
parser.add_argument("--policy_noise", default=0.2, type=float)  # Noise added to target policy during critic update
parser.add_argument("--noise_clip", default=0.2, type=float)  # Range to clip target policy noise
parser.add_argument("--policy_freq", default=2, type=int)  # Frequency of delayed policy updates
parser.add_argument("--max_episode_steps", default=200, type=int)
args = parser.parse_args()

# A single environment instance is shared by every run below.
env = ArmEnv()
policy_name_vec = ['TD3_RNN', 'ATD3_RNN']
for policy_name in policy_name_vec:
    # Five runs per policy; each run overwrites args.seed with 0..4, so the
    # --seed command-line value is not used by these runs.
    for i in range(0, 5):
        args.policy_name = policy_name
        args.seed = i
        main(env, args)