Example #1
def train_sqil(env, n=0):
    # Train SQIL for n random seeds on the given environment id, using a
    # small expert dataset (5 trajectories).
    venv = gym.make(env)
    expert_data = make_sa_dataset(env, max_trajs=5)

    for i in range(n):
        # Create model
        if isinstance(venv.action_space, Discrete):
            model = DQN(SQLPolicy,
                        venv,
                        verbose=1,
                        policy_kwargs=dict(net_arch=[64, 64]),
                        learning_starts=1)
        else:
            model = SAC('MlpPolicy',
                        venv,
                        verbose=1,
                        policy_kwargs=dict(net_arch=[256, 256]),
                        ent_coef='auto',
                        learning_rate=linear_schedule(7.3e-4),
                        train_freq=64,
                        gradient_steps=64,
                        gamma=0.98,
                        tau=0.02)

        # Replace the default replay buffer with a SQIL buffer pre-filled
        # with the expert transitions.
        model.replay_buffer = SQILReplayBuffer(model.buffer_size,
                                               model.observation_space,
                                               model.action_space,
                                               model.device,
                                               1,
                                               model.optimize_memory_usage,
                                               expert_data=expert_data)
        mean_rewards = []
        std_rewards = []
        # Alternate between training and periodic evaluation.
        for train_steps in range(20):
            if train_steps > 0:
                if 'Bullet' in env:
                    model.learn(total_timesteps=25000, log_interval=1)
                else:
                    model.learn(total_timesteps=16384, log_interval=1)
            mean_reward, std_reward = evaluate_policy(model,
                                                      model.env,
                                                      n_eval_episodes=10)
            mean_rewards.append(mean_reward)
            std_rewards.append(std_reward)
            print("{0} Steps: {1}".format(train_steps, mean_reward))
            # Overwrite the saved reward curve after every evaluation.
            np.savez(os.path.join("learners", env,
                                  "sqil_rewards_{0}".format(i)),
                     means=mean_rewards,
                     stds=std_rewards)
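The SAC branch above sets learning_rate=linear_schedule(7.3e-4), but linear_schedule is a project helper that is not part of the snippet. A minimal sketch, assuming the usual Stable-Baselines3 convention of calling the schedule with the remaining training progress (1.0 at the start of training, 0.0 at the end):

def linear_schedule(initial_value):
    # Stable-Baselines3 calls the schedule with the remaining training
    # progress, which decreases linearly from 1.0 to 0.0.
    def schedule(progress_remaining):
        return progress_remaining * initial_value
    return schedule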
Example #2
#env = gym.make('CartPole-v1')
env = gym.make('FrozenLake-v0')

log_dir = "/tmp/gym/"
os.makedirs(log_dir, exist_ok=True)
env = Monitor(env, log_dir)

model = DQN('MlpPolicy', env, verbose=1, batch_size=32,
            learning_starts=1000)  #prioritized_replay=True

model.replay_buffer = TrajReplayBuffer(model.buffer_size,
                                       model.observation_space,
                                       model.action_space,
                                       model.device,
                                       trajectory=True,
                                       seq_num=1)
# Time the full training run.
initial_time = round(time(), 2)
model.learn(total_timesteps=int(100000))

mean_reward, std_reward = evaluate_policy(model,
                                          env,
                                          n_eval_episodes=10,
                                          deterministic=True)

finish_time = round(time(), 2)
total_time = round(finish_time - initial_time, 2)
print("this run took total time of {0} seconds".format(total_time))
plot_results(log_dir)
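plot_results is likewise a helper that is not part of the snippet. A minimal sketch, assuming it reads the Monitor logs written to log_dir and plots episode reward against timesteps with Stable-Baselines3's results_plotter utilities:

import matplotlib.pyplot as plt
from stable_baselines3.common.results_plotter import load_results, ts2xy

def plot_results(log_folder):
    # Load the Monitor CSVs written to log_folder and plot reward over time.
    x, y = ts2xy(load_results(log_folder), 'timesteps')
    plt.plot(x, y)
    plt.xlabel('Timesteps')
    plt.ylabel('Episode reward')
    plt.show()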
Example #3
def train_adril(env, n=0, balanced=False):
    # Train AdRIL for n random seeds on the given environment id, using
    # num_trajs expert demonstrations.
    num_trajs = 20
    expert_data = make_sa_dataset(env, max_trajs=num_trajs)
    n_expert = len(expert_data["obs"])
    expert_sa = np.concatenate(
        (expert_data["obs"], np.reshape(expert_data["acts"], (n_expert, -1))),
        axis=1)

    for i in range(n):
        venv = AdRILWrapper(gym.make(env))
        mean_rewards = []
        std_rewards = []
        # Create model
        if isinstance(venv.action_space, Discrete):
            model = DQN(SQLPolicy,
                        venv,
                        verbose=1,
                        policy_kwargs=dict(net_arch=[64, 64]),
                        learning_starts=1)
        else:
            model = SAC('MlpPolicy',
                        venv,
                        verbose=1,
                        policy_kwargs=dict(net_arch=[256, 256]),
                        ent_coef='auto',
                        learning_rate=linear_schedule(7.3e-4),
                        train_freq=64,
                        gradient_steps=64,
                        gamma=0.98,
                        tau=0.02)
        # Replace the default replay buffer with an AdRIL buffer that mixes
        # expert and learner transitions.
        model.replay_buffer = AdRILReplayBuffer(model.buffer_size,
                                                model.observation_space,
                                                model.action_space,
                                                model.device,
                                                1,
                                                model.optimize_memory_usage,
                                                expert_data=expert_data,
                                                N_expert=num_trajs,
                                                balanced=balanced)
        # Pre-load the expert transitions into the buffer when not using the
        # balanced sampling variant.
        if not balanced:
            for j in range(len(expert_sa)):
                obs = expert_data["obs"][j]
                act = expert_data["acts"][j]
                next_obs = expert_data["next_obs"][j]
                done = expert_data["dones"][j]
                model.replay_buffer.add(obs, next_obs, act, -1, done)
        for train_steps in range(400):
            # Train policy
            if train_steps > 0:
                if 'Bullet' in env:
                    model.learn(total_timesteps=1250, log_interval=1000)
                else:
                    model.learn(total_timesteps=25000, log_interval=1000)
                if train_steps % 1 == 0:  # written to support more complex update schemes
                    model.replay_buffer.set_iter(train_steps)
                    model.replay_buffer.set_n_learner(venv.num_trajs)

            # Evaluate policy
            if train_steps % 20 == 0:
                model.set_env(gym.make(env))
                mean_reward, std_reward = evaluate_policy(model,
                                                          model.env,
                                                          n_eval_episodes=10)
                mean_rewards.append(mean_reward)
                std_rewards.append(std_reward)
                print("{0} Steps: {1}".format(int(train_steps * 1250),
                                              mean_reward))
                np.savez(os.path.join("learners", env,
                                      "adril_rewards_{0}".format(i)),
                         means=mean_rewards,
                         stds=std_rewards)
            # Update env
            if train_steps > 0:
                if train_steps % 1 == 0:
                    venv.set_iter(train_steps + 1)
            model.set_env(venv)
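Examples #1 and #3 define training functions but never call them. A hypothetical invocation of train_adril (the environment id and seed count are placeholders; the learners/<env> directory used by np.savez above must exist beforehand):

# Hypothetical usage; adjust env_id and n to the experiment being run.
env_id = "HalfCheetahBulletEnv-v0"
os.makedirs(os.path.join("learners", env_id), exist_ok=True)
train_adril(env_id, n=3, balanced=False)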