def test_model(self) -> Tuple[List[float], list]:
    ray.init(logging_level=logging.INFO, ignore_reinit_error=True)
    agent = DQNTrainer(self.config, env=custom_env_name)
    # Load the trained policy weights onto the CPU and push them into the agent.
    weights = torch.load(
        self.params.model_dir / "trained_model.pt",
        map_location=lambda storage, loc: storage,
    )
    agent.set_weights({"default_policy": weights})

    rewards = []
    longest_screens = []
    for i in range(self.params.num_testing_episodes):
        screens = []
        try:
            logger.info("Iteration: {}", i)
            state = self.env.reset()
            done = False
            cumulative_reward = 0
            while not done:
                action = agent.compute_action(state)
                state, reward, done, _ = self.env.step(action)
                screen = self.env.render(mode="rgb_array")
                screens.append(screen)
                cumulative_reward += reward
                time.sleep(0.01)
            logger.info("Iteration: {}, Reward: {}", i, cumulative_reward)
            rewards.append(cumulative_reward)
        except KeyboardInterrupt:
            logger.info("Testing was interrupted")
            break
        # Keep the frames of the longest episode seen so far.
        if len(screens) > len(longest_screens):
            longest_screens = screens
    self.env.close()
    ray.shutdown()
    return rewards, longest_screens
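
# Hypothetical counterpart to the weight loading above (not from the original
# code): after training, the policy weights could be written out with torch.save
# so that test_model() can later restore them via agent.set_weights(). `agent`
# and `model_dir` are placeholders for the trained DQNTrainer instance and the
# same directory referenced by self.params.model_dir.
import torch

weights = agent.get_weights()["default_policy"]
torch.save(weights, model_dir / "trained_model.pt")
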
def evaluate_model(args):
    if args.model_path == '':
        print('Cannot evaluate model, no --model_path set')
        exit(1)

    def get_env():
        # Simulator env uses a single map, so better for evaluation/testing.
        # DiscreteWrapper just converts wheel velocities to high level discrete actions.
        return DiscreteWrapper(
            simulator.Simulator(
                map_name=args.map,
                max_steps=2000,
            ))

    # Rather than reuse the env, another one is created later because I can't
    # figure out how to provide register_env with an object.
    register_env('DuckieTown-Simulator', lambda _: get_env())
    trainer = DQNTrainer(
        env="DuckieTown-Simulator",
        config={
            "framework": "torch",
            "model": {
                "custom_model": "image-dqn",
            },
            # Dueling off
            "dueling": False,
            # No hidden layers
            "hiddens": [],
        },
    )
    trainer.restore(args.model_path)

    sim_env = get_env()

    # Standard OpenAI Gym reset/action/step/render loop.
    # This matches how the `enjoy_reinforcement.py` script works, see: https://git.io/J3js2
    done = False
    observation = sim_env.reset()
    episode_reward = 0
    while not done:
        action = trainer.compute_action(observation)
        observation, reward, done, _ = sim_env.step(action)
        episode_reward += reward
        sim_env.render()

    print(f'Episode complete, total reward: {episode_reward}')
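
# Hypothetical entry point (not part of the original code): a minimal argparse
# wrapper that supplies the two attributes evaluate_model() reads, args.model_path
# and args.map. The default map name is illustrative only.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Evaluate a trained DQN policy in the Duckietown simulator.")
    parser.add_argument("--model_path", default="",
                        help="Path to an RLlib checkpoint produced by DQNTrainer.save()")
    parser.add_argument("--map", default="loop_empty",
                        help="Name of the Duckietown map to load")
    evaluate_model(parser.parse_args())
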
import logging
import os
import pickle

from ray.rllib.agents.dqn import DQNTrainer
from tqdm import trange

LOGGER = logging.getLogger(__name__)


class DQNrl(object):
    def __init__(self, env, env_config, config):
        self.config = config
        self.config['env_config'] = env_config
        # Keep a local env instance for prediction rollouts; RLlib builds its
        # own copies from the env class and env_config.
        self.env = env(env_config)
        self.agent = DQNTrainer(config=self.config, env=env)

    def fit(self, checkpoint=None):
        if checkpoint is None:
            checkpoint = os.path.join(os.getcwd(), 'data/checkpoint_rl.pkl')
        for idx in trange(5):
            result = self.agent.train()
            LOGGER.warning('result: %s', result)
            if (idx + 1) % 5 == 0:
                LOGGER.warning('Save checkpoint at: {}'.format(idx + 1))
                # Serialise the trainer state so predict() can restore it later.
                state = self.agent.save_to_object()
                with open(checkpoint, 'wb') as fp:
                    pickle.dump(state, fp, protocol=pickle.HIGHEST_PROTOCOL)
        return result

    def predict(self, checkpoint=None):
        if checkpoint is not None:
            with open(checkpoint, 'rb') as fp:
                state = pickle.load(fp)
            self.agent.restore_from_object(state)
        done = False
        episode_reward = 0
        obs = self.env.reset()
        actions = []
        # Greedy rollout with the trained policy on the local environment copy.
        while not done:
            action = self.agent.compute_action(obs)
            actions.append(action)
            obs, reward, done, info = self.env.step(action)
            episode_reward += reward
        results = {'action': actions, 'reward': episode_reward}
        return results
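
# Hypothetical usage sketch (not part of the original code): driving DQNrl end
# to end. `MyEnv` stands for any gym.Env subclass whose constructor accepts an
# env_config dict, as RLlib expects; the config values below are illustrative.
import ray
from ray.rllib.agents.dqn import DEFAULT_CONFIG

ray.init(ignore_reinit_error=True)

config = DEFAULT_CONFIG.copy()
config.update({"framework": "torch", "num_workers": 0})

model = DQNrl(env=MyEnv, env_config={"max_steps": 200}, config=config)
model.fit()  # trains for 5 iterations and pickles a checkpoint under data/
results = model.predict(checkpoint=os.path.join(os.getcwd(), 'data/checkpoint_rl.pkl'))
print(results['reward'])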