import os
from time import time

import gym
import matplotlib.pyplot as plt
import numpy as np
from gym.spaces import Discrete
from stable_baselines3 import DQN, SAC
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.monitor import Monitor

# Project-local helpers assumed importable from the surrounding repo:
# make_sa_dataset, SQLPolicy, linear_schedule, SQILReplayBuffer,
# TrajReplayBuffer, AdRILWrapper, AdRILReplayBuffer, plot_results.


def train_sqil(env, n=0):
    venv = gym.make(env)
    expert_data = make_sa_dataset(env, max_trajs=5)
    for i in range(n):
        # Create model: soft-Q DQN for discrete action spaces, SAC otherwise.
        if isinstance(venv.action_space, Discrete):
            model = DQN(SQLPolicy, venv, verbose=1,
                        policy_kwargs=dict(net_arch=[64, 64]),
                        learning_starts=1)
        else:
            model = SAC('MlpPolicy', venv, verbose=1,
                        policy_kwargs=dict(net_arch=[256, 256]),
                        ent_coef='auto',
                        learning_rate=linear_schedule(7.3e-4),
                        train_freq=64, gradient_steps=64,
                        gamma=0.98, tau=0.02)
        # Swap in the SQIL buffer, which mixes expert and learner transitions.
        model.replay_buffer = SQILReplayBuffer(
            model.buffer_size, model.observation_space, model.action_space,
            model.device, 1, model.optimize_memory_usage,  # n_envs=1
            expert_data=expert_data)
        mean_rewards = []
        std_rewards = []
        for train_steps in range(20):
            # Train policy
            if train_steps > 0:
                if 'Bullet' in env:
                    model.learn(total_timesteps=25000, log_interval=1)
                else:
                    model.learn(total_timesteps=16384, log_interval=1)
            # Evaluate policy
            mean_reward, std_reward = evaluate_policy(model, model.env,
                                                      n_eval_episodes=10)
            mean_rewards.append(mean_reward)
            std_rewards.append(std_reward)
            print("{0} Steps: {1}".format(train_steps, mean_reward))
            np.savez(os.path.join("learners", env,
                                  "sqil_rewards_{0}".format(i)),
                     means=mean_rewards, stds=std_rewards)
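# Example usage (a sketch; the environment IDs below are illustrative and
# assume expert demonstrations are available to make_sa_dataset):
#
#   train_sqil('CartPole-v1', n=3)              # discrete actions -> DQN
#   train_sqil('HalfCheetahBulletEnv-v0', n=3)  # continuous actions -> SAC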
# Timed DQN run on FrozenLake with a trajectory-level replay buffer.
plt.show()  # flush any figures left over from a previous run

# env = gym.make('CartPole-v1')
env = gym.make('FrozenLake-v0')
log_dir = "/tmp/gym/"
os.makedirs(log_dir, exist_ok=True)
env = Monitor(env, log_dir)

model = DQN('MlpPolicy', env, verbose=1, batch_size=32,
            learning_starts=1000)  # prioritized_replay=True
model.replay_buffer = TrajReplayBuffer(model.buffer_size,
                                       model.observation_space,
                                       model.action_space, model.device,
                                       trajectory=True, seq_num=1)

initial_time = round(time(), 2)
model.learn(total_timesteps=100000)
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10,
                                          deterministic=True)
finish_time = round(time(), 2)
total_time = round(finish_time - initial_time, 2)
print("This run took {0} seconds in total.".format(total_time))
plot_results(log_dir)
def train_adril(env, n=0, balanced=False):
    num_trajs = 20
    expert_data = make_sa_dataset(env, max_trajs=num_trajs)
    n_expert = len(expert_data["obs"])
    expert_sa = np.concatenate(
        (expert_data["obs"],
         np.reshape(expert_data["acts"], (n_expert, -1))), axis=1)
    for i in range(n):
        venv = AdRILWrapper(gym.make(env))
        mean_rewards = []
        std_rewards = []
        # Create model: soft-Q DQN for discrete action spaces, SAC otherwise.
        if isinstance(venv.action_space, Discrete):
            model = DQN(SQLPolicy, venv, verbose=1,
                        policy_kwargs=dict(net_arch=[64, 64]),
                        learning_starts=1)
        else:
            model = SAC('MlpPolicy', venv, verbose=1,
                        policy_kwargs=dict(net_arch=[256, 256]),
                        ent_coef='auto',
                        learning_rate=linear_schedule(7.3e-4),
                        train_freq=64, gradient_steps=64,
                        gamma=0.98, tau=0.02)
        model.replay_buffer = AdRILReplayBuffer(
            model.buffer_size, model.observation_space, model.action_space,
            model.device, 1, model.optimize_memory_usage,  # n_envs=1
            expert_data=expert_data, N_expert=num_trajs, balanced=balanced)
        # In the unbalanced variant, pre-load every expert transition with
        # reward -1; the buffer re-weights samples as training progresses.
        if not balanced:
            for j in range(len(expert_sa)):
                obs = expert_data["obs"][j]
                act = expert_data["acts"][j]
                next_obs = expert_data["next_obs"][j]
                done = expert_data["dones"][j]
                model.replay_buffer.add(obs, next_obs, act, -1, done)
        for train_steps in range(400):
            # Train policy
            if train_steps > 0:
                if 'Bullet' in env:
                    model.learn(total_timesteps=1250, log_interval=1000)
                else:
                    model.learn(total_timesteps=25000, log_interval=1000)
            if train_steps % 1 == 0:  # written to support more complex update schemes
                model.replay_buffer.set_iter(train_steps)
                model.replay_buffer.set_n_learner(venv.num_trajs)
            # Evaluate policy
            if train_steps % 20 == 0:
                model.set_env(gym.make(env))
                mean_reward, std_reward = evaluate_policy(model, model.env,
                                                          n_eval_episodes=10)
                mean_rewards.append(mean_reward)
                std_rewards.append(std_reward)
                print("{0} Steps: {1}".format(int(train_steps * 1250),
                                              mean_reward))
                np.savez(os.path.join("learners", env,
                                      "adril_rewards_{0}".format(i)),
                         means=mean_rewards, stds=std_rewards)
            # Update env
            if train_steps > 0:
                if train_steps % 1 == 0:
                    venv.set_iter(train_steps + 1)
            # Restore the AdRIL wrapper so the next round trains on it
            # (evaluation above may have swapped in a plain env).
            model.set_env(venv)
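# Example usage (a sketch; assumes a "learners/<env>/" directory exists for
# the saved reward curves, as in train_sqil above):
#
#   train_adril('HalfCheetahBulletEnv-v0', n=3, balanced=False)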