def train(self, env, time_remaining=1e9, noise_type=None, noise_value=None):
    time_start = time.time()

    sd = 1 if env.has_discrete_state_space() else self.state_dim
    ad = 1 if env.has_discrete_action_space() else self.action_dim
    replay_buffer = ReplayBuffer(state_dim=sd,
                                 action_dim=ad,
                                 device=self.device,
                                 max_size=self.rb_size)

    avg_meter_reward = AverageMeter(print_str="Average reward: ")

    # training loop
    for episode in range(self.train_episodes):
        # early out if timeout
        if self.time_is_up(avg_meter_reward=avg_meter_reward,
                           max_episodes=self.train_episodes,
                           time_elapsed=time.time() - time_start,
                           time_remaining=time_remaining):
            break

        self.update_parameters_per_episode(episode=episode)

        state = env.reset()
        episode_reward = 0

        for t in range(0, env.max_episode_steps()):
            action = self.select_train_action(state=state, env=env)

            # live view
            if self.render_env and episode % 10 == 0:
                env.render()

            # state-action transition
            next_state, reward, done = env.step(action=action)
            next_state, reward = self.apply_noise(next_state=next_state,
                                                  reward=reward,
                                                  noise_type=noise_type,
                                                  noise_value=noise_value)

            replay_buffer.add(state=state, action=action, next_state=next_state,
                              reward=reward, done=done)

            state = next_state
            episode_reward += reward

            # train
            if episode >= self.init_episodes:
                self.learn(replay_buffer=replay_buffer, env=env)

            if done > 0.5:
                break

        # logging
        avg_meter_reward.update(episode_reward, print_rate=self.print_rate)

        # quit training if environment is solved
        if self.env_solved(env=env, avg_meter_reward=avg_meter_reward, episode=episode):
            break

    env.close()

    return avg_meter_reward.get_raw_data(), replay_buffer
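# Hedged sketch, not the repo's implementation: a minimal ring buffer that
# satisfies the interface used by train() above (constructor kwargs state_dim,
# action_dim, device, max_size; add(); plus the sample()/clear() calls made
# elsewhere in these methods). The storage layout and sampling details of the
# real ReplayBuffer class may differ; the class name below is deliberately
# different to avoid confusion.

import numpy as np
import torch


class MinimalReplayBuffer:
    def __init__(self, state_dim, action_dim, device, max_size=int(1e6)):
        self.device = device
        self.max_size = max_size
        self.ptr = 0
        self.size = 0
        self.state = np.zeros((max_size, state_dim), dtype=np.float32)
        self.action = np.zeros((max_size, action_dim), dtype=np.float32)
        self.next_state = np.zeros((max_size, state_dim), dtype=np.float32)
        self.reward = np.zeros((max_size, 1), dtype=np.float32)
        self.done = np.zeros((max_size, 1), dtype=np.float32)

    def add(self, state, action, next_state, reward, done):
        # overwrite the oldest transition once the buffer is full
        self.state[self.ptr] = state
        self.action[self.ptr] = action
        self.next_state[self.ptr] = next_state
        self.reward[self.ptr] = reward
        self.done[self.ptr] = done
        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample(self, batch_size):
        # uniform sampling of stored transitions, returned as tensors on device
        idx = np.random.randint(0, self.size, size=batch_size)
        return tuple(torch.as_tensor(arr[idx], device=self.device)
                     for arr in (self.state, self.action, self.next_state,
                                 self.reward, self.done))

    def clear(self):
        self.ptr = 0
        self.size = 0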
def test(self, env, time_remaining=1e9):
    discretize_action = False

    sd = 1 if env.has_discrete_state_space() else self.state_dim
    if env.has_discrete_action_space():
        ad = 1
        # in case of td3_discrete, action_dim=1 does not reflect the required
        # action_dim for the gumbel softmax distribution
        if self.agent_name == "td3_discrete_vary":
            ad = env.get_action_dim()
            discretize_action = True
    else:
        ad = self.action_dim

    replay_buffer = ReplayBuffer(state_dim=sd,
                                 action_dim=ad,
                                 device=self.device,
                                 max_size=int(1e6))

    env.set_agent_params(same_action_num=self.same_action_num, gamma=self.gamma)

    with torch.no_grad():
        time_start = time.time()

        avg_meter_reward = AverageMeter(print_str="Average reward: ")
        avg_meter_episode_length = AverageMeter(print_str="Average episode length: ")

        # test loop
        for episode in range(self.test_episodes):
            # early out if timeout
            if self.time_is_up(avg_meter_reward=avg_meter_reward,
                               avg_meter_episode_length=avg_meter_episode_length,
                               max_episodes=self.test_episodes,
                               time_elapsed=time.time() - time_start,
                               time_remaining=time_remaining):
                break

            state = env.reset()
            episode_reward = 0
            episode_length = 0

            for t in range(0, env.max_episode_steps(), self.same_action_num):
                action = self.select_test_action(state, env)

                # live view
                if self.render_env:
                    env.render()

                # state-action transition
                # discretization required due to the gumbel softmax in td3_discrete
                if discretize_action:
                    next_state, reward, done = env.step(action=action.argmax().unsqueeze(0))
                else:
                    next_state, reward, done = env.step(action=action)

                replay_buffer.add(state=state, action=action, next_state=next_state,
                                  reward=reward, done=done)

                state = next_state
                episode_reward += reward
                episode_length += 1

                if done > 0.5:
                    break

            # logging
            avg_meter_episode_length.update(episode_length, print_rate=1e9)
            avg_meter_reward.update(episode_reward.item(), print_rate=self.print_rate)

    env.close()

    # todo: use a dict to reduce confusion and bugs
    return avg_meter_reward.get_raw_data(), avg_meter_episode_length.get_raw_data(), replay_buffer
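# Hedged usage sketch: one way the three return values of test() might be
# consumed. summarize_test() is a hypothetical helper, not part of the repo; it
# assumes an already constructed agent and test environment matching the
# interface used above.

import statistics


def summarize_test(agent, test_env, time_remaining=300):
    rewards, episode_lengths, _ = agent.test(test_env, time_remaining=time_remaining)
    print("episodes evaluated: ", len(rewards))
    print("mean episode reward:", statistics.mean(rewards))
    print("mean episode length:", statistics.mean(episode_lengths))
    return statistics.mean(rewards)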
def train(self, env, test_env=None, time_remaining=1e9):
    time_start = time.time()

    discretize_action = False

    sd = 1 if env.has_discrete_state_space() else self.state_dim
    # todo: @fabio use "hasattr" and a custom function in the derived class (see below)
    if env.has_discrete_action_space():
        ad = 1
        # in case of td3_discrete, action_dim=1 does not reflect the required
        # action_dim for the gumbel softmax distribution
        if "td3_discrete" in self.agent_name:
            ad = env.get_action_dim()
            discretize_action = True
    else:
        ad = self.action_dim

    replay_buffer = ReplayBuffer(state_dim=sd,
                                 action_dim=ad,
                                 device=self.device,
                                 max_size=self.rb_size)

    avg_meter_reward = AverageMeter(print_str="Average reward: ")
    avg_meter_episode_length = AverageMeter(print_str="Average episode length: ")

    env.set_agent_params(same_action_num=self.same_action_num, gamma=self.gamma)

    # training loop
    for episode in range(self.train_episodes):
        # early out if timeout
        if self.time_is_up(avg_meter_reward=avg_meter_reward,
                           avg_meter_episode_length=avg_meter_episode_length,
                           max_episodes=self.train_episodes,
                           time_elapsed=time.time() - time_start,
                           time_remaining=time_remaining):
            break

        if hasattr(self, 'update_parameters_per_episode'):
            self.update_parameters_per_episode(episode=episode)

        state = env.reset()
        episode_reward = 0
        episode_length = 0

        for t in range(0, env.max_episode_steps(), self.same_action_num):
            action = self.select_train_action(state=state, env=env, episode=episode)

            # live view
            if self.render_env:
                env.render()

            # state-action transition
            # discretization required due to the gumbel softmax in td3_discrete
            # todo @fabio: move into the agent-specific select_train_action, do the same for test
            if discretize_action:
                next_state, reward, done = env.step(action=action.argmax().unsqueeze(0))
            else:
                next_state, reward, done = env.step(action=action)

            replay_buffer.add(state=state, action=action, next_state=next_state,
                              reward=reward, done=done)

            state = next_state
            episode_reward += reward
            episode_length += self.same_action_num

            # train
            if episode >= self.init_episodes:
                self.learn(replay_buffer=replay_buffer, env=env, episode=episode)

            if done > 0.5:
                break

        # logging
        avg_meter_episode_length.update(episode_length, print_rate=1e9)
        if test_env is not None:
            avg_reward_test_raw, _, _ = self.test(test_env)
            avg_meter_reward.update(statistics.mean(avg_reward_test_raw),
                                    print_rate=self.print_rate)
        else:
            avg_meter_reward.update(episode_reward, print_rate=self.print_rate)

        # quit training if environment is solved
        if episode >= self.init_episodes:
            break_env = test_env if test_env is not None else env
            if self.env_solved(env=break_env, avg_meter_reward=avg_meter_reward, episode=episode):
                print('early out after ' + str(episode) + ' episodes')
                break

    env.close()

    # todo: use a dict to reduce confusion and bugs
    return avg_meter_reward.get_raw_data(), avg_meter_episode_length.get_raw_data(), replay_buffer
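# Hedged sketch: the hasattr() check above treats update_parameters_per_episode()
# as an optional, agent-specific hook. The class and schedule below are purely
# illustrative (hypothetical names) to show what such a hook could look like,
# e.g. a linear exploration-noise decay; the repo's actual agents may implement
# something different.

class ExampleAgentWithSchedule:
    def __init__(self, eps_init=1.0, eps_min=0.05, eps_decay_episodes=500):
        self.eps = eps_init
        self.eps_init = eps_init
        self.eps_min = eps_min
        self.eps_decay_episodes = eps_decay_episodes

    def update_parameters_per_episode(self, episode):
        # linearly anneal exploration noise over the first eps_decay_episodes episodes
        frac = min(episode / self.eps_decay_episodes, 1.0)
        self.eps = self.eps_init + frac * (self.eps_min - self.eps_init)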
def train(self, env, time_remaining=1e9, test_env=None):
    sd = 1 if env.has_discrete_state_space() else self.state_dim
    ad = 1 if env.has_discrete_action_space() else self.action_dim
    replay_buffer = ReplayBuffer(state_dim=sd,
                                 action_dim=ad,
                                 device=self.device,
                                 max_size=self.rb_size)

    avg_meter_reward = AverageMeter(print_str="Average reward: ")
    avg_meter_episode_length = AverageMeter(print_str="Average episode length: ")

    env.set_agent_params(same_action_num=self.same_action_num, gamma=self.gamma)

    time_step = 0

    # training loop
    for episode in range(self.train_episodes):
        state = env.reset()
        episode_reward = 0
        episode_length = 0

        for t in range(0, env.max_episode_steps(), self.same_action_num):
            time_step += self.same_action_num

            # run old policy
            action = self.actor_old(state.to(self.device)).cpu()
            next_state, reward, done = env.step(action=action)

            # live view
            if self.render_env and episode % 100 == 0:
                env.render()

            replay_buffer.add(state=state, action=action, next_state=next_state,
                              reward=reward, done=done)

            state = next_state
            episode_reward += reward
            episode_length += self.same_action_num

            # train after a certain number of timesteps
            if time_step / env.max_episode_steps() > self.update_episodes:
                self.learn(replay_buffer)
                replay_buffer.clear()
                time_step = 0

            if done > 0.5:
                break

        # logging
        avg_meter_episode_length.update(episode_length, print_rate=1e9)
        if test_env is not None:
            avg_reward_test_raw, _, _ = self.test(test_env)
            avg_meter_reward.update(statistics.mean(avg_reward_test_raw),
                                    print_rate=self.print_rate)
        else:
            avg_meter_reward.update(episode_reward, print_rate=self.print_rate)

        # quit training if environment is solved
        if episode >= self.init_episodes:
            break_env = test_env if test_env is not None else env
            if self.env_solved(env=break_env, avg_meter_reward=avg_meter_reward, episode=episode):
                print('early out after ' + str(episode) + ' episodes')
                break

    env.close()

    return avg_meter_reward.get_raw_data(), avg_meter_episode_length.get_raw_data(), {}
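# The update trigger above is time-step based rather than episode based: the
# policy is updated (and the on-policy buffer cleared) once roughly
# update_episodes worth of environment steps have been collected. The numbers
# below are made up for illustration only; a small standalone check of that
# arithmetic:

def steps_until_update(max_episode_steps, update_episodes, same_action_num=1):
    # smallest accumulated time_step count satisfying
    # time_step / max_episode_steps > update_episodes
    time_step = 0
    while time_step / max_episode_steps <= update_episodes:
        time_step += same_action_num
    return time_step


# with max_episode_steps=200 and update_episodes=2, learn() fires after 401 steps
assert steps_until_update(max_episode_steps=200, update_episodes=2) == 401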