    custom_region_available = False
    for key, value in env_conf['useful_region'].items():
        if key in args.env_name:
            env_conf['useful_region'] = value
            custom_region_available = True
            break
    if not custom_region_available:
        env_conf['useful_region'] = env_conf['useful_region']['Default']
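    # For reference, a sketch of the (assumed) shape of the useful_region
    # config that this selection expects: keys are substrings of Atari env
    # IDs, with 'Default' as the fallback; the actual values come from the
    # params file loaded into env_conf:
    #   "useful_region": {
    #       "Default": {...},
    #       "Pong": {...},
    #       "Breakout": {...}
    #   }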

    print("Using env_conf:", env_conf)
    atari_env = False
    for game in Atari.get_games_list():
        if game in args.env_name.lower():
            atari_env = True
    if atari_env:
        env = Atari.make_env(args.env_name, env_conf)
    else:
        print("Given environment name is not an Atari Env. Creating a Gym env")
        env = gym.make(args.env_name)

    observation_shape = env.observation_space.shape
    action_shape = env.action_space.n
    agent_params = params_manager.get_agent_params()
    agent = Deep_Q_Learner(observation_shape, action_shape, agent_params)

    episode_rewards = list()
    prev_checkpoint_mean_ep_rew = agent.best_mean_reward
    num_improved_episodes_before_checkpoint = 0  # Number of episodes with improved performance since the last checkpoint
    print("Using agent_params:", agent_params)
    if agent_params['load_trained_model']:
        try:
            # (Assumed completion: load the saved model the same way the
            # load/except pattern in Example #2 does; the exact agent.load()
            # signature is an assumption.)
            agent.load(args.env_name)
            prev_checkpoint_mean_ep_rew = agent.best_mean_reward
        except FileNotFoundError:
            print("WARNING: No trained model found for this environment. Training from scratch.")

Example #2
    def run(self):
        # If a custom useful_region configuration is available for this
        # environment ID, use it; otherwise fall back to the Default. This is
        # currently utilized only for the Atari envs and follows the same
        # procedure as in Chapter 6.
        custom_region_available = False
        for key, value in self.env_conf['useful_region'].items():
            if key in args.env:
                self.env_conf['useful_region'] = value
                custom_region_available = True
                break
        if not custom_region_available:
            self.env_conf['useful_region'] = self.env_conf['useful_region']['Default']
        atari_env = False
        for game in Atari.get_games_list():
            if game in args.env.lower():
                atari_env = True
        if atari_env:  # Use the Atari wrappers (like we did in Chapter 6) if it's an Atari env
            self.env = Atari.make_env(self.env_name, self.env_conf)
        else:
            #print("Given environment name is not an Atari Env. Creating a Gym env")
            self.env = gym.make(self.env_name)

        self.state_shape = self.env.observation_space.shape
        if isinstance(self.env.action_space.sample(), int):  # Discrete action space
            self.action_shape = self.env.action_space.n
            self.policy = self.discrete_policy
            self.continuous_action_space = False
        else:  # Continuous action space
            self.action_shape = self.env.action_space.shape[0]
            self.policy = self.multi_variate_gaussian_policy
            self.continuous_action_space = True  # set explicitly so the check below never sees an unset attribute
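        # (Detecting discreteness by sampling works with classic Gym spaces;
        # an equivalent check would be
        # isinstance(self.env.action_space, gym.spaces.Discrete), noted here
        # as an alternative, not what this example uses.)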
        self.critic_shape = 1
        if len(self.state_shape) == 3:  # Screen image is the input to the agent
            if self.continuous_action_space:
                self.actor = DeepActor(self.state_shape, self.action_shape,
                                       device).to(device)
            else:  # Discrete action space
                self.actor = DeepDiscreteActor(self.state_shape,
                                               self.action_shape,
                                               device).to(device)
            self.critic = DeepCritic(self.state_shape, self.critic_shape,
                                     device).to(device)
        else:  # Input is a (single dimensional) vector
            if self.continuous_action_space:
                #self.actor_critic = ShallowActorCritic(self.state_shape, self.action_shape, 1, self.params).to(device)
                self.actor = ShallowActor(self.state_shape, self.action_shape,
                                          device).to(device)
            else:  # Discrete action space
                self.actor = ShallowDiscreteActor(self.state_shape,
                                                  self.action_shape,
                                                  device).to(device)
            self.critic = ShallowCritic(self.state_shape, self.critic_shape,
                                        device).to(device)
        self.actor_optimizer = torch.optim.Adam(
            self.actor.parameters(), lr=self.params["learning_rate"])
        self.critic_optimizer = torch.optim.Adam(
            self.critic.parameters(), lr=self.params["learning_rate"])
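        # Note: the actor and critic use separate Adam optimizers, but both
        # read the same learning_rate from the params file.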

        # Handle loading and saving of trained Agent models
        episode_rewards = list()
        prev_checkpoint_mean_ep_rew = self.best_mean_reward
        num_improved_episodes_before_checkpoint = 0  # Number of episodes with improved performance since the last checkpoint
        #print("Using agent_params:", self.params)
        if self.params['load_trained_model']:
            try:
                self.load()
                prev_checkpoint_mean_ep_rew = self.best_mean_reward
            except FileNotFoundError:
                if args.test:  # Test a saved model
                    print(
                        "FATAL: No saved model found. Cannot test. Press any key to train from scratch"
                    )
                    input()
                else:
                    print(
                        "WARNING: No trained model found for this environment. Training from scratch."
                    )

        for episode in range(self.params["max_num_episodes"]):
            obs = self.env.reset()
            done = False
            ep_reward = 0.0
            step_num = 0
            while not done:
                action = self.get_action(obs)
                next_obs, reward, done, _ = self.env.step(action)
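                # (Classic Gym API: step() returns (obs, reward, done, info)
                # and reset() returns just the initial observation.)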
                self.rewards.append(reward)
                ep_reward += reward
                step_num += 1
                if not args.test and (
                        step_num >= self.params["learning_step_thresh"]
                        or done):
                    self.learn(next_obs, done)
                    step_num = 0
                    # Monitor performance and save Agent's state when perf improves
                    if done:
                        episode_rewards.append(ep_reward)
                        if ep_reward > self.best_reward:
                            self.best_reward = ep_reward
                        if np.mean(
                                episode_rewards) > prev_checkpoint_mean_ep_rew:
                            num_improved_episodes_before_checkpoint += 1
                        if num_improved_episodes_before_checkpoint >= self.params[
                                "save_freq_when_perf_improves"]:
                            prev_checkpoint_mean_ep_rew = np.mean(
                                episode_rewards)
                            self.best_mean_reward = np.mean(episode_rewards)
                            self.save()
                            num_improved_episodes_before_checkpoint = 0

                obs = next_obs
                self.global_step_num += 1
                if args.render:
                    self.env.render()
                #print(self.actor_name + ":Episode#:", episode, "step#:", step_num, "\t rew=", reward, end="\r")
                writer.add_scalar(self.actor_name + "/reward", reward,
                                  self.global_step_num)
            print(
                "{}:Episode#:{} \t ep_reward:{} \t mean_ep_rew:{}\t best_ep_reward:{}"
                .format(self.actor_name, episode, ep_reward,
                        np.mean(episode_rewards), self.best_reward))
            writer.add_scalar(self.actor_name + "/ep_reward", ep_reward,
                              self.global_step_num)
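
The checkpointing above saves the agent only after the mean episode reward has
beaten the previous checkpoint's mean for several episodes in a row. A minimal
standalone sketch of that save-on-improvement rule (function and parameter
names here are illustrative, not from the book's code):

import numpy as np

def should_checkpoint(episode_rewards, prev_checkpoint_mean, streak, save_freq):
    """Return (checkpoint_now, new_streak) for the save-on-improvement rule."""
    if np.mean(episode_rewards) > prev_checkpoint_mean:
        streak += 1  # one more episode above the last checkpointed mean
    if streak >= save_freq:
        return True, 0  # time to save; reset the improvement streak
    return False, streak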
Example #3
    custom_region_available = False
    for key, value in env_conf['useful_region'].items():
        if key in args.env:
            env_conf['useful_region'] = value
            custom_region_available = True
            break
    if not custom_region_available:
        env_conf['useful_region'] = env_conf['useful_region']['Default']

    print("Using env_conf:", env_conf)
    atari_env = False
    for game in Atari.get_games_list():
        if game.replace("_", "") in args.env.lower():
            atari_env = True
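    # (Names from Atari.get_games_list() are snake_case, e.g. "space_invaders";
    # stripping the underscores lets them match lowercased env IDs such as
    # "spaceinvadersnoframeskip-v4".)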
    if atari_env:
        env = Atari.make_env(args.env, env_conf)
    else:
        print("Given environment name is not an Atari Env. Creating a Gym env")
        env = env_utils.ResizeReshapeFrames(gym.make(args.env))

    if args.record:
        env = gym.wrappers.Monitor(env, args.recording_output_dir, force=True)

    observation_shape = env.observation_space.shape
    action_shape = env.action_space.n
    agent_params = params_manager.get_agent_params()
    agent_params["test"] = args.test
    agent = Deep_Q_Learner(observation_shape, action_shape, agent_params)

    episode_rewards = list()
    prev_checkpoint_mean_ep_rew = agent.best_mean_reward
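
env_utils.ResizeReshapeFrames above is the book's own wrapper; a minimal
sketch of what such an observation wrapper could look like (the behavior of
resizing image observations and moving channels first for PyTorch, as well as
the class name and default sizes, are assumptions):

import gym
import numpy as np
import cv2

class ResizeReshapeFramesSketch(gym.ObservationWrapper):
    """Resize image observations and convert them from HWC to CHW layout."""
    def __init__(self, env, width=84, height=84):
        super().__init__(env)
        self.width, self.height = width, height
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=(3, height, width), dtype=np.uint8)

    def observation(self, obs):
        obs = cv2.resize(obs, (self.width, self.height))
        return np.transpose(obs, (2, 0, 1))  # HWC -> CHW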