def mytest(env_name, eval_episode=10, num_init_traj=1, max_horizon=15, ensemble=1, gt=0, finetune=False, finetune_iter=41, finetune_proc=10, cem_iter=20):
    """Evaluate a policy combined with CEM-based online system identification.

    Builds a policy (AWR or finetuned), a parallel dynamics model and a
    randomized-dynamics env, runs `online_osi`, and returns reward statistics
    plus the mean parameter-estimation distance.
    """
    # Number of randomized dynamics parameters for each supported env.
    env_param_counts = {
        'HalfCheetahPT-v2': 6,
        'HopperPT-v2': 5,
        'Walker2dPT-v2': 8,
    }
    param_dim = env_param_counts[env_name]

    if finetune:
        policy_net = get_finetune_network(env_name, param_dim,
                                          num_iter=finetune_iter,
                                          num_proc=finetune_proc)
    else:
        policy_net = get_awr_network(env_name, param_dim)

    model = make_parallel(10, env_name, num=param_dim, stochastic=False)
    env = make(env_name, num=param_dim, resample_MP=True, stochastic=False)
    params = get_params(env)

    # Start the CEM search from the center of the normalized parameter range.
    init_mean = np.array([0.5] * len(params))
    osi = CEMOSI(model, init_mean, iter_num=cem_iter,
                 num_mutation=100, num_elite=10, std=0.3)

    rewards, dist = online_osi(
        env, osi, policy_net,
        num_init_traj=num_init_traj, max_horizon=max_horizon,
        eval_episodes=eval_episode, use_state=False, print_timestep=10000,
        resample_MP=True, ensemble=ensemble, online=0, gt=gt)

    rewards = np.array(rewards)
    print('l2 distance', dist)
    print('rewards', rewards)
    return {
        'mean': rewards.mean(),
        'std': rewards.std(),
        'min': rewards.min(),
        'max': rewards.max(),
        'dist': dist.mean(),
    }
def test_model():
    # Sanity check of the parallel dynamics model: roll out the same random
    # action sequence in the real env and in the batched model, then print
    # the model's observations for the last batch element.
    env_name = 'DartHopperPT-v1'
    env = make_parallel(1, env_name, num=2)
    env2 = make(env_name, num=2, stochastic=False)
    batch_size = 30
    horizon = 100
    s = []
    # Collect one post-reset state per batch element.
    for i in range(batch_size):
        env2.reset()
        s.append(get_state(env2))
    param = get_params(env2)
    # All batch elements share the same dynamics parameters.
    params = np.array([param for i in range(batch_size)])
    env2.env.noisy_input = False
    s = np.array(s)
    # Random action sequences: shape (batch_size, horizon, action_dim).
    a = [[env2.action_space.sample() for j in range(horizon)]
         for i in range(batch_size)]
    a = np.array(a)
    # Step the real env with the last batch element's first few actions
    # (reference trajectory for eyeball comparison).
    for i in range(3):
        obs, _, done, _ = env2.step(a[-1][i])
        if done:
            break
    # Single timed model rollout over the whole batch.
    for i in tqdm.trange(1):
        r, obs, mask = env(params, s, a)
    # NOTE(review): prints the model rollout's first 3 observations for the
    # last batch element — presumably to compare against env2's rollout above.
    print(obs[-1][:3])
def test_up_diff():
    """Run gradient-based online system identification (DiffOSI) on Hopper."""
    env_name = 'HopperPT-v2'
    param_dim = 5

    policy_net = get_awr_network(env_name, param_dim)
    model = make_parallel(30, env_name, num=param_dim, stochastic=False)
    env = make(env_name, num=param_dim, resample_MP=True, stochastic=False)
    params = get_params(env)

    # Experiment residue: several ground-truth candidates were tried in turn;
    # only the last assignment takes effect.
    #set_params(env, [0.55111654,0.55281674,0.46355396,0.84531834,0.58944066])
    for candidate in (
        [0.31851129, 0.93941556, 0.02147825, 0.43523052, 1.02611646],
        [0.94107358, 0.77519005, 0.44055224, 0.9369426, -0.03846457],
        [0.05039606, 0.14680257, 0.56502066, 0.25723492, 0.73810709],
    ):
        set_params(env, candidate)

    # Initialize the estimate at the middle of the normalized range.
    initial_guess = np.array([0.5] * len(params))
    osi = DiffOSI(model, initial_guess, 0.001, iter=100, momentum=0.9, eps=1e-3)
    policy_net.set_params(initial_guess)

    # I run this at the last time..  online is very useful ..
    online_osi(env, osi, policy_net, num_init_traj=5, max_horizon=15,
               eval_episodes=20, use_state=False, print_timestep=10000,
               resample_MP=True, online=0)
def test_up_osi():
    """Run CEM-based online system identification with an AWR policy."""
    #env_name = 'DartHopperPT-v1'
    env_name = 'HopperPT-v2'
    param_dim = 5

    #policy_net = get_up_network(env_name, param_dim)
    policy_net = get_awr_network(env_name, param_dim)
    model = make_parallel(10, env_name, num=param_dim, stochastic=False)
    env = make(env_name, num=param_dim, resample_MP=True, stochastic=False)
    params = get_params(env)

    # Ground-truth candidates tried during experiments; the final call wins.
    #set_params(env, [0.55111654,0.55281674,0.46355396,0.84531834,0.58944066])
    set_params(env, [0.31851129, 0.93941556, 0.02147825, 0.43523052, 1.02611646])
    set_params(env, [0.94107358, 0.77519005, 0.44055224, 0.9369426, -0.03846457])
    set_params(env, [0.05039606, 0.14680257, 0.56502066, 0.25723492, 0.73810709])

    # Start the CEM search (and the policy's belief) at the range center.
    initial_guess = np.array([0.5] * len(params))
    osi = CEMOSI(model, initial_guess, iter_num=20,
                 num_mutation=100, num_elite=10, std=0.3)
    policy_net.set_params(initial_guess)

    online_osi(env, osi, policy_net, num_init_traj=5, max_horizon=15,
               eval_episodes=30, use_state=False, print_timestep=10000,
               resample_MP=True, ensemble=1, online=0, gt=0)
def test_POLO():
    """Evaluate a POLO planner (TD3 value net + parallel model) on Hopper."""
    #env_name = 'DartWalker2dPT-v1'
    #num = 8
    env_name = 'DartHopperPT-v1'
    param_dim = 5

    value_net = get_td3_value(env_name)
    #value_net = None

    # Pull planner hyper-parameters from the shared CLI argument parser.
    parser = argparse.ArgumentParser()
    add_parser(parser)
    args = parser.parse_args()

    model = make_parallel(args.num_proc, env_name, num=param_dim, stochastic=True)
    env = make(env_name, num=param_dim, resample_MP=False)

    controller = POLO(
        value_net, model,
        action_space=env.action_space,
        add_actions=args.add_actions,
        horizon=args.horizon,
        std=args.std,
        iter_num=args.iter_num,
        initial_iter=args.initial_iter,
        num_mutation=args.num_mutation,
        num_elite=args.num_elite,
        alpha=0.1,
        trunc_norm=True,
        lower_bound=env.action_space.low,
        upper_bound=env.action_space.high,
        replan_period=5)

    trajectories = eval_policy(controller, env, 10,
                               args.video_num, args.video_path,
                               timestep=args.timestep,
                               set_gt_params=True, print_timestep=100)
def test_cem_osi():
    """CEM-based OSI driving a POLO planner (no value net) on HopperPT-v3."""
    env_name = 'HopperPT-v3'
    param_dim = 5

    from networks import get_td3_value
    #value_net = get_td3_value(env_name)
    value_net = None

    from policy import POLO, add_parser
    import argparse
    parser = argparse.ArgumentParser()
    add_parser(parser)
    args = parser.parse_args()
    args.num_proc = 20

    model = make_parallel(args.num_proc, env_name, num=param_dim, stochastic=True)
    env = make(env_name, num=param_dim, resample_MP=True)

    # Planner search budget overridden for this experiment.
    #args.iter_num = 2
    args.num_mutation = 500
    #args.num_mutation = 100
    args.iter_num = 5
    args.num_elite = 10

    policy_net = POLO(
        value_net, model,
        action_space=env.action_space,
        add_actions=args.add_actions,
        horizon=args.horizon,
        std=args.std,
        iter_num=args.iter_num,
        initial_iter=args.initial_iter,
        num_mutation=args.num_mutation,
        num_elite=args.num_elite,
        alpha=0.1,
        trunc_norm=True,
        lower_bound=env.action_space.low,
        upper_bound=env.action_space.high)

    # Rebuild the evaluation env deterministically (no observation noise).
    resample_MP = True
    env = make(env_name, num=param_dim, resample_MP=resample_MP, stochastic=False)
    params = get_params(env)

    print("FIXXXXXXXXXXXXXXXXXXXXXXPARAMETERS")
    # Several fixed ground-truth candidates; only the final call takes effect.
    set_params(env, np.array([0.58093299, 0.05418986, 0.93399553, 0.1678795, 1.04150952]))
    set_params(env, [0.55111654, 0.55281674, 0.46355396, 0.84531834, 0.58944066])
    set_params(env, [0.31851129, 0.93941556, 0.02147825, 0.43523052, 1.02611646])
    set_params(env, [0.58589476, 0.11078934, 0.348238, 0.68130195, 0.98376274])

    mean_params = np.array([0.5] * len(params))
    osi = CEMOSI(model, mean_params, iter_num=20, num_mutation=100,
                 num_elite=10, std=0.3, ensemble_num=5)
    policy_net.set_params(mean_params)
    print(get_params(env))

    online_osi(env, osi, policy_net, num_init_traj=1, max_horizon=15,
               eval_episodes=10, use_state=True, print_timestep=10,
               resample_MP=resample_MP, online=0, ensemble=5)