def step(self, state, action, reward, next_state, done, info):
    # Don't treat reaching the max episode steps (a timeout) as a true terminal;
    # instead, reset the exploration noise.
    max_steps_reached = info["logs"]["events"].reached_max_episode_steps
    reset_noise = False
    if max_steps_reached:
        done = False
        reset_noise = True

    output = {}
    action = to_2d_action(action)
    self.memory.add(
        state=state,
        action=action,
        reward=reward,
        next_state=next_state,
        done=float(done),
        social_capacity=self.social_capacity,
        observation_num_lookahead=self.observation_num_lookahead,
        social_vehicle_config=self.social_vehicle_config,
        prev_action=self.prev_action,
    )
    self.step_count += 1
    if reset_noise:
        self.reset()
    # Learn only after the warmup/batch threshold is met, at the update rate.
    if (
        len(self.memory) > max(self.batch_size, self.warmup)
        and (self.step_count + 1) % self.update_rate == 0
    ):
        output = self.learn()

    self.prev_action = action if not done else np.zeros(self.action_size)
    return output
def step(self, state, action, reward, next_state, done, info):
    # Don't treat reaching the max episode steps (a timeout) as a true terminal.
    max_steps_reached = info["logs"]["events"].reached_max_episode_steps
    if max_steps_reached:
        done = False

    action = to_2d_action(action)
    # Pass social_vehicle_rep through the network.
    # state["low_dim_states"] = torch.from_numpy(
    #     np.float32(np.append(state["low_dim_states"], self.prev_action))
    # ).unsqueeze(0)
    self.log_probs.append(self.current_log_prob.to(self.device))
    self.values.append(self.current_value.to(self.device))
    self.states.append(state)
    self.rewards.append(torch.FloatTensor([reward]).to(self.device))
    self.actions.append(
        torch.FloatTensor(action.reshape(self.action_size,)).to(self.device)
    )
    self.terminals.append(1.0 - float(done))

    output = {}
    # Batch updates over multiple episodes.
    if len(self.terminals) >= self.batch_size:
        output = self.learn()

    self.prev_action = action if not done else np.zeros(self.action_size)
    return output
def step(self, state, action, reward, next_state, done):
    # Don't treat reaching the max episode steps (a timeout) as a true terminal.
    max_steps_reached = state["events"].reached_max_episode_steps
    if max_steps_reached:
        done = False

    action = to_2d_action(action)
    state = self.state_preprocessor(
        state=state,
        normalize=True,
        device=self.device,
        social_capacity=self.social_capacity,
        observation_num_lookahead=self.observation_num_lookahead,
        social_vehicle_config=self.social_vehicle_config,
        prev_action=self.prev_action,
    )
    # Pass social_vehicle_rep through the network.
    self.log_probs.append(self.current_log_prob.to(self.device))
    self.values.append(self.current_value.to(self.device))
    self.states.append(state)
    self.rewards.append(torch.FloatTensor([reward]).to(self.device))
    self.actions.append(
        torch.FloatTensor(action.reshape(self.action_size,)).to(self.device)
    )
    self.terminals.append(1.0 - float(done))

    output = {}
    # Batch updates over multiple episodes.
    if len(self.terminals) >= self.batch_size:
        output = self.learn()

    self.prev_action = action if not done else np.zeros(self.action_size)
    return output
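# In both PPO variants above, terminals stores (1.0 - done) as a float so it can
# serve as a bootstrap mask when discounted returns are computed: a 0.0 entry cuts
# the value carried across an episode boundary. Below is a generic sketch of that
# use, assuming the buffer conventions built in step(); it is not taken from the
# repository's learn().
import torch


def discounted_returns(rewards, terminals, last_value, gamma=0.99):
    # rewards: list of 1-element tensors; terminals: list of floats, as in step().
    returns = []
    R = last_value
    for reward, mask in zip(reversed(rewards), reversed(terminals)):
        R = reward + gamma * R * mask  # mask == 0.0 stops bootstrapping at episode end
        returns.insert(0, R)
    return torch.cat(returns)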
def step(self, state, action, reward, next_state, done):
    # Don't treat reaching the max episode steps (a timeout) as a true terminal.
    max_steps_reached = state["events"].reached_max_episode_steps
    if max_steps_reached:
        done = False

    action = to_2d_action(action)
    self.memory.add(
        state=state,
        action=action,
        reward=reward,
        next_state=next_state,
        done=float(done),
        social_capacity=self.social_capacity,
        observation_num_lookahead=self.observation_num_lookahead,
        social_vehicle_config=self.social_vehicle_config,
        prev_action=self.prev_action,
    )
    self.steps += 1
    output = {}
    if self.steps > max(self.warmup, self.batch_size):
        states, actions, rewards, next_states, dones, others = self.memory.sample(
            device=self.device_name
        )
        # Critic and actor/temperature updates run on their own schedules.
        if self.steps % self.critic_update_rate == 0:
            critic_loss = self.update_critic(
                states, actions, rewards, next_states, dones
            )
            output["loss/critic_loss"] = {
                "type": "scalar",
                "data": critic_loss.item(),
                "freq": 2,
            }
        if self.steps % self.policy_update_rate == 0:
            actor_loss, temp_loss = self.update_actor_temp(
                states, actions, rewards, next_states, dones
            )
            output["loss/actor_loss"] = {
                "type": "scalar",
                "data": actor_loss.item(),
                "freq": self.logging_freq,
            }
            output["loss/temp_loss"] = {
                "type": "scalar",
                "data": temp_loss.item(),
                "freq": self.logging_freq,
            }
            output["others/alpha"] = {
                "type": "scalar",
                "data": self.sac_net.alpha.item(),
                "freq": self.logging_freq,
            }
            self.current_iteration += 1
        self.target_soft_update(self.sac_net.critic, self.sac_net.target, self.tau)

    self.prev_action = action if not done else np.zeros(self.action_size)
    return output
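# The SAC variant above delegates target-network maintenance to
# target_soft_update(). A minimal sketch of the standard Polyak-averaging update
# it presumably performs; the actual signature and implementation in the
# repository may differ.
import torch


def target_soft_update(online: torch.nn.Module, target: torch.nn.Module, tau: float):
    # Blend online weights into the target: target <- tau * online + (1 - tau) * target.
    with torch.no_grad():
        for online_p, target_p in zip(online.parameters(), target.parameters()):
            target_p.mul_(1.0 - tau).add_(online_p, alpha=tau)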
def step(self, state, action, reward, next_state, done, others=None):
    # Don't treat reaching the max episode steps (a timeout) as a true terminal.
    max_steps_reached = state["events"].reached_max_episode_steps
    if max_steps_reached:
        done = False

    # Map the action onto a discrete index for the replay buffer.
    if self.action_space_type == "continuous":
        action = to_2d_action(action)
        _action = (
            [[e] for e in action]
            if not self.merge_action_spaces
            else [action.tolist()]
        )
        action_index = np.asarray(
            [
                action2index[str(e)]
                for action2index, e in zip(self.action2indexs, _action)
            ]
        )
    else:
        action_index = self.lane_actions.index(action)
        action = action_index
    self.replay.add(
        state=state,
        action=action_index,
        reward=reward,
        next_state=next_state,
        done=done,
        others=others,
        social_capacity=self.social_capacity,
        observation_num_lookahead=self.observation_num_lookahead,
        social_vehicle_config=self.social_vehicle_config,
        prev_action=self.prev_action,
    )
    if (
        self.step_count % self.train_step == 0
        and len(self.replay) >= self.batch_size
        and (self.warmup is None or len(self.replay) >= self.warmup)
    ):
        out = self.learn()
        self.update_count += 1
    else:
        out = {}

    # target_update > 1 means a periodic hard copy; a fraction means a soft update.
    if self.target_update > 1 and self.step_count % self.target_update == 0:
        self.update_target_network()
    elif self.target_update < 1.0:
        self.soft_update(
            self.target_q_network, self.online_q_network, self.target_update
        )
    self.step_count += 1
    self.prev_action = action
    return out
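# Usage sketch (an assumption, not code from the repository): a generic loop
# driving one of these agents. `env`, `agent`, and `agent.act` are hypothetical
# stand-ins. Note the variants differ slightly in signature (the first two take
# `info`, the DQN variant takes an optional `others`), so adapt the call to the
# agent in use; this loop matches the first two.
def run_episode(agent, env, max_steps=1000):
    state = env.reset()
    logs = {}
    for _ in range(max_steps):
        action = agent.act(state)
        next_state, reward, done, info = env.step(action)
        # step() records the transition and, when its update schedule fires,
        # runs a learning update; it returns a dict of logging scalars.
        logs = agent.step(state, action, reward, next_state, done, info)
        state = next_state
        if done:
            break
    return logs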