class TD3Policy(Agent): def __init__( self, policy_params=None, checkpoint_dir=None, ): self.policy_params = policy_params self.action_size = int(policy_params["action_size"]) self.action_range = np.asarray([[-1.0, 1.0], [-1.0, 1.0]], dtype=np.float32) self.actor_lr = float(policy_params["actor_lr"]) self.critic_lr = float(policy_params["critic_lr"]) self.critic_wd = float(policy_params["critic_wd"]) self.actor_wd = float(policy_params["actor_wd"]) self.noise_clip = float(policy_params["noise_clip"]) self.policy_noise = float(policy_params["policy_noise"]) self.update_rate = int(policy_params["update_rate"]) self.policy_delay = int(policy_params["policy_delay"]) self.warmup = int(policy_params["warmup"]) self.critic_tau = float(policy_params["critic_tau"]) self.actor_tau = float(policy_params["actor_tau"]) self.gamma = float(policy_params["gamma"]) self.batch_size = int(policy_params["batch_size"]) self.sigma = float(policy_params["sigma"]) self.theta = float(policy_params["theta"]) self.dt = float(policy_params["dt"]) self.action_low = torch.tensor( [[each[0] for each in self.action_range]]) self.action_high = torch.tensor( [[each[1] for each in self.action_range]]) self.seed = int(policy_params["seed"]) self.prev_action = np.zeros(self.action_size) # state preprocessing self.social_policy_hidden_units = int( policy_params["social_vehicles"].get("social_policy_hidden_units", 0)) self.social_capacity = int(policy_params["social_vehicles"].get( "social_capacity", 0)) self.observation_num_lookahead = int( policy_params.get("observation_num_lookahead", 0)) self.social_policy_init_std = int(policy_params["social_vehicles"].get( "social_policy_init_std", 0)) self.num_social_features = int(policy_params["social_vehicles"].get( "num_social_features", 0)) self.social_vehicle_config = get_social_vehicle_configs( **policy_params["social_vehicles"]) self.social_vehicle_encoder = self.social_vehicle_config["encoder"] self.state_description = BaselineStatePreprocessor.get_state_description( policy_params["social_vehicles"], policy_params["observation_num_lookahead"], self.action_size, ) self.social_feature_encoder_class = self.social_vehicle_encoder[ "social_feature_encoder_class"] self.social_feature_encoder_params = self.social_vehicle_encoder[ "social_feature_encoder_params"] # others self.checkpoint_dir = checkpoint_dir self.device_name = "cuda:0" if torch.cuda.is_available() else "cpu" self.device = torch.device(self.device_name) self.save_codes = (policy_params["save_codes"] if "save_codes" in policy_params else None) self.memory = ReplayBuffer( buffer_size=int(policy_params["replay_buffer"]["buffer_size"]), batch_size=int(policy_params["replay_buffer"]["batch_size"]), device_name=self.device_name, ) self.num_actor_updates = 0 self.current_iteration = 0 self.step_count = 0 self.init_networks() if checkpoint_dir: self.load(checkpoint_dir) @property def state_size(self): # Adjusting state_size based on number of features (ego+social) size = sum(self.state_description["low_dim_states"].values()) if self.social_feature_encoder_class: size += self.social_feature_encoder_class( **self.social_feature_encoder_params).output_dim else: size += self.social_capacity * self.num_social_features # adding the previous action size += self.action_size return size def init_networks(self): self.noise = [ OrnsteinUhlenbeckProcess(size=(1, ), theta=0.01, std=LinearSchedule(0.25), mu=0.0, x0=0.0, dt=1.0), # throttle OrnsteinUhlenbeckProcess(size=(1, ), theta=0.1, std=LinearSchedule(0.05), mu=0.0, x0=0.0, dt=1.0), # 
steering ] self.actor = ActorNetwork( self.state_size, self.action_size, self.seed, social_feature_encoder=self.social_feature_encoder_class( **self.social_feature_encoder_params) if self.social_feature_encoder_class else None, ).to(self.device) self.actor_target = ActorNetwork( self.state_size, self.action_size, self.seed, social_feature_encoder=self.social_feature_encoder_class( **self.social_feature_encoder_params) if self.social_feature_encoder_class else None, ).to(self.device) self.actor_target.load_state_dict(self.actor.state_dict()) self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=self.actor_lr) self.critic_1 = CriticNetwork( self.state_size, self.action_size, self.seed, social_feature_encoder=self.social_feature_encoder_class( **self.social_feature_encoder_params) if self.social_feature_encoder_class else None, ).to(self.device) self.critic_1_target = CriticNetwork( self.state_size, self.action_size, self.seed, social_feature_encoder=self.social_feature_encoder_class( **self.social_feature_encoder_params) if self.social_feature_encoder_class else None, ).to(self.device) self.critic_1_target.load_state_dict(self.critic_1.state_dict()) self.critic_1_optimizer = optim.Adam(self.critic_1.parameters(), lr=self.critic_lr) self.critic_2 = CriticNetwork( self.state_size, self.action_size, self.seed, social_feature_encoder=self.social_feature_encoder_class( **self.social_feature_encoder_params) if self.social_feature_encoder_class else None, ).to(self.device) self.critic_2_target = CriticNetwork( self.state_size, self.action_size, self.seed, social_feature_encoder=self.social_feature_encoder_class( **self.social_feature_encoder_params) if self.social_feature_encoder_class else None, ).to(self.device) self.critic_2_target.load_state_dict(self.critic_2.state_dict()) self.critic_2_optimizer = optim.Adam(self.critic_2.parameters(), lr=self.critic_lr) def act(self, state, explore=True): state = copy.deepcopy(state) state["low_dim_states"] = np.float32( np.append(state["low_dim_states"], self.prev_action)) state["social_vehicles"] = (torch.from_numpy( state["social_vehicles"]).unsqueeze(0).to(self.device)) state["low_dim_states"] = (torch.from_numpy( state["low_dim_states"]).unsqueeze(0).to(self.device)) self.actor.eval() action = self.actor(state).cpu().data.numpy().flatten() noise = [self.noise[0].sample(), self.noise[1].sample()] if explore: action[0] += noise[0] action[1] += noise[1] self.actor.train() action_low, action_high = ( self.action_low.data.cpu().numpy(), self.action_high.data.cpu().numpy(), ) action = np.clip(action, action_low, action_high)[0] return to_3d_action(action) def step(self, state, action, reward, next_state, done, info): # dont treat timeout as done equal to True max_steps_reached = info["logs"]["events"].reached_max_episode_steps reset_noise = False if max_steps_reached: done = False reset_noise = True output = {} action = to_2d_action(action) self.memory.add( state=state, action=action, reward=reward, next_state=next_state, done=float(done), social_capacity=self.social_capacity, observation_num_lookahead=self.observation_num_lookahead, social_vehicle_config=self.social_vehicle_config, prev_action=self.prev_action, ) self.step_count += 1 if reset_noise: self.reset() if (len(self.memory) > max(self.batch_size, self.warmup) and (self.step_count + 1) % self.update_rate == 0): output = self.learn() self.prev_action = action if not done else np.zeros(self.action_size) return output def reset(self): self.noise[0].reset_states() self.noise[1].reset_states() def 
learn(self): output = {} states, actions, rewards, next_states, dones, others = self.memory.sample( device=self.device) actions = actions.squeeze(dim=1) next_actions = self.actor_target(next_states) noise = torch.randn_like(next_actions).mul(self.policy_noise) noise = noise.clamp(-self.noise_clip, self.noise_clip) next_actions += noise next_actions = torch.max( torch.min(next_actions, self.action_high.to(self.device)), self.action_low.to(self.device), ) target_Q1 = self.critic_1_target(next_states, next_actions) target_Q2 = self.critic_2_target(next_states, next_actions) target_Q = torch.min(target_Q1, target_Q2) target_Q = (rewards + ((1 - dones) * self.gamma * target_Q)).detach() # Optimize Critic 1: current_Q1, aux_losses_Q1 = self.critic_1(states, actions, training=True) loss_Q1 = F.mse_loss(current_Q1, target_Q) + compute_sum_aux_losses(aux_losses_Q1) self.critic_1_optimizer.zero_grad() loss_Q1.backward() self.critic_1_optimizer.step() # Optimize Critic 2: current_Q2, aux_losses_Q2 = self.critic_2(states, actions, training=True) loss_Q2 = F.mse_loss(current_Q2, target_Q) + compute_sum_aux_losses(aux_losses_Q2) self.critic_2_optimizer.zero_grad() loss_Q2.backward() self.critic_2_optimizer.step() # delayed actor updates if (self.step_count + 1) % self.policy_delay == 0: critic_out = self.critic_1(states, self.actor(states), training=True) actor_loss, actor_aux_losses = -critic_out[0], critic_out[1] actor_loss = actor_loss.mean() + compute_sum_aux_losses( actor_aux_losses) self.actor_optimizer.zero_grad() actor_loss.backward() self.actor_optimizer.step() self.soft_update(self.actor_target, self.actor, self.actor_tau) self.num_actor_updates += 1 output = { "loss/critic_1": { "type": "scalar", "data": loss_Q1.data.cpu().numpy(), "freq": 10, }, "loss/actor": { "type": "scalar", "data": actor_loss.data.cpu().numpy(), "freq": 10, }, } self.soft_update(self.critic_1_target, self.critic_1, self.critic_tau) self.soft_update(self.critic_2_target, self.critic_2, self.critic_tau) self.current_iteration += 1 return output def soft_update(self, target, src, tau): for target_param, param in zip(target.parameters(), src.parameters()): target_param.detach_() target_param.copy_(target_param * (1.0 - tau) + param * tau) def load(self, model_dir): model_dir = pathlib.Path(model_dir) map_location = None if self.device and self.device.type == "cpu": map_location = "cpu" self.actor.load_state_dict( torch.load(model_dir / "actor.pth", map_location=map_location)) self.actor_target.load_state_dict( torch.load(model_dir / "actor_target.pth", map_location=map_location)) self.critic_1.load_state_dict( torch.load(model_dir / "critic_1.pth", map_location=map_location)) self.critic_1_target.load_state_dict( torch.load(model_dir / "critic_1_target.pth", map_location=map_location)) self.critic_2.load_state_dict( torch.load(model_dir / "critic_2.pth", map_location=map_location)) self.critic_2_target.load_state_dict( torch.load(model_dir / "critic_2_target.pth", map_location=map_location)) def save(self, model_dir): model_dir = pathlib.Path(model_dir) torch.save(self.actor.state_dict(), model_dir / "actor.pth") torch.save( self.actor_target.state_dict(), model_dir / "actor_target.pth", ) torch.save(self.critic_1.state_dict(), model_dir / "critic_1.pth") torch.save( self.critic_1_target.state_dict(), model_dir / "critic_1_target.pth", ) torch.save(self.critic_2.state_dict(), model_dir / "critic_2.pth") torch.save( self.critic_2_target.state_dict(), model_dir / "critic_2_target.pth", )
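
# A minimal, self-contained sketch (not part of this codebase) of the target
# computation that TD3Policy.learn() above performs: target policy smoothing
# with clipped Gaussian noise, then a clipped double-Q target from the two
# target critics. `actor_target`, `critic_1_target`, and `critic_2_target` are
# assumed to be any callables returning tensors; batch shapes are illustrative.
import torch


def td3_target(
    rewards,
    dones,
    next_states,
    actor_target,
    critic_1_target,
    critic_2_target,
    gamma=0.99,
    policy_noise=0.2,
    noise_clip=0.5,
    action_low=-1.0,
    action_high=1.0,
):
    """Return r + gamma * (1 - done) * min(Q1', Q2') with smoothed target actions."""
    with torch.no_grad():
        next_actions = actor_target(next_states)
        noise = torch.randn_like(next_actions).mul(policy_noise)
        noise = noise.clamp(-noise_clip, noise_clip)
        next_actions = (next_actions + noise).clamp(action_low, action_high)
        target_q = torch.min(
            critic_1_target(next_states, next_actions),
            critic_2_target(next_states, next_actions),
        )
        return rewards + (1.0 - dones) * gamma * target_q
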
class DQNPolicy(Agent): lane_actions = [ "keep_lane", "slow_down", "change_lane_left", "change_lane_right" ] def __init__( self, policy_params=None, checkpoint_dir=None, ): self.policy_params = policy_params network_class = DQNWithSocialEncoder self.epsilon_obj = EpsilonExplore(1.0, 0.05, 100000) action_space_type = policy_params["action_space_type"] if action_space_type == "continuous": discrete_action_spaces = [ np.asarray([-0.25, 0.0, 0.5, 0.75, 1.0]), np.asarray([ -1.0, -0.75, -0.5, -0.25, -0.1, 0.0, 0.1, 0.25, 0.5, 0.75, 1.0 ]), ] else: discrete_action_spaces = [[0], [1]] action_size = discrete_action_spaces self.merge_action_spaces = 0 if action_space_type == "continuous" else -1 self.step_count = 0 self.update_count = 0 self.num_updates = 0 self.current_sticky = 0 self.current_iteration = 0 lr = float(policy_params["lr"]) seed = int(policy_params["seed"]) self.train_step = int(policy_params["train_step"]) self.target_update = float(policy_params["target_update"]) self.device_name = "cuda:0" if torch.cuda.is_available() else "cpu" self.device = torch.device(self.device_name) self.warmup = int(policy_params["warmup"]) self.gamma = float(policy_params["gamma"]) self.batch_size = int(policy_params["batch_size"]) self.use_ddqn = policy_params["use_ddqn"] self.sticky_actions = int(policy_params["sticky_actions"]) prev_action_size = int(policy_params["prev_action_size"]) self.prev_action = np.zeros(prev_action_size) if self.merge_action_spaces == 1: index2action, action2index = merge_discrete_action_spaces( *action_size) self.index2actions = [index2action] self.action2indexs = [action2index] self.num_actions = [len(self.index2actions)] elif self.merge_action_spaces == 0: self.index2actions = [ merge_discrete_action_spaces([each])[0] for each in action_size ] self.action2indexs = [ merge_discrete_action_spaces([each])[1] for each in action_size ] self.num_actions = [len(e) for e in action_size] else: index_to_actions = [ e.tolist() if not isinstance(e, list) else e for e in action_size ] action_to_indexs = { str(k): v for k, v in zip( index_to_actions, np.arange(len(index_to_actions)).astype(np.int)) } self.index2actions, self.action2indexs = ( [index_to_actions], [action_to_indexs], ) self.num_actions = [len(index_to_actions)] # state preprocessing self.social_policy_hidden_units = int( policy_params["social_vehicles"].get("social_policy_hidden_units", 0)) self.social_capacity = int(policy_params["social_vehicles"].get( "social_capacity", 0)) self.observation_num_lookahead = int( policy_params.get("observation_num_lookahead", 0)) self.social_polciy_init_std = int(policy_params["social_vehicles"].get( "social_polciy_init_std", 0)) self.num_social_features = int(policy_params["social_vehicles"].get( "num_social_features", 0)) self.social_vehicle_config = get_social_vehicle_configs( **policy_params["social_vehicles"]) self.social_vehicle_encoder = self.social_vehicle_config["encoder"] self.state_description = get_state_description( policy_params["social_vehicles"], policy_params["observation_num_lookahead"], prev_action_size, ) self.social_feature_encoder_class = self.social_vehicle_encoder[ "social_feature_encoder_class"] self.social_feature_encoder_params = self.social_vehicle_encoder[ "social_feature_encoder_params"] self.checkpoint_dir = checkpoint_dir self.reset() torch.manual_seed(seed) network_params = { "state_size": self.state_size, "social_feature_encoder_class": self.social_feature_encoder_class, "social_feature_encoder_params": self.social_feature_encoder_params, } 
self.online_q_network = network_class( num_actions=self.num_actions, **(network_params if network_params else {}), ).to(self.device) self.target_q_network = network_class( num_actions=self.num_actions, **(network_params if network_params else {}), ).to(self.device) self.update_target_network() self.optimizers = torch.optim.Adam( params=self.online_q_network.parameters(), lr=lr) self.loss_func = nn.MSELoss(reduction="none") if self.checkpoint_dir: self.load(self.checkpoint_dir) self.action_space_type = "continuous" self.to_real_action = to_3d_action self.state_preprocessor = StatePreprocessor(preprocess_state, to_2d_action, self.state_description) self.replay = ReplayBuffer( buffer_size=int(policy_params["replay_buffer"]["buffer_size"]), batch_size=int(policy_params["replay_buffer"]["batch_size"]), state_preprocessor=self.state_preprocessor, device_name=self.device_name, ) def lane_action_to_index(self, state): state = state.copy() if (len(state["action"]) == 3 and (state["action"] == np.asarray( [0, 0, 0])).all()): # initial action state["action"] = np.asarray([0]) else: state["action"] = self.lane_actions.index(state["action"]) return state @property def state_size(self): # Adjusting state_size based on number of features (ego+social) size = sum(self.state_description["low_dim_states"].values()) if self.social_feature_encoder_class: size += self.social_feature_encoder_class( **self.social_feature_encoder_params).output_dim else: size += self.social_capacity * self.num_social_features return size def reset(self): self.eps_throttles = [] self.eps_steers = [] self.eps_step = 0 self.current_sticky = 0 def soft_update(self, target, src, tau): for target_param, param in zip(target.parameters(), src.parameters()): target_param.detach_() target_param.copy_(target_param * (1.0 - tau) + param * tau) def update_target_network(self): self.target_q_network.load_state_dict( self.online_q_network.state_dict().copy()) def act(self, *args, **kwargs): if self.current_sticky == 0: self.action = self._act(*args, **kwargs) self.current_sticky = (self.current_sticky + 1) % self.sticky_actions self.current_iteration += 1 return self.to_real_action(self.action) def _act(self, state, explore=True): epsilon = self.epsilon_obj.get_epsilon() if not explore or np.random.rand() > epsilon: state = self.state_preprocessor( state, normalize=True, unsqueeze=True, device=self.device, social_capacity=self.social_capacity, observation_num_lookahead=self.observation_num_lookahead, social_vehicle_config=self.social_vehicle_config, prev_action=self.prev_action, ) self.online_q_network.eval() with torch.no_grad(): qs = self.online_q_network(state) qs = [q.data.cpu().numpy().flatten() for q in qs] # out_str = " || ".join( # [ # " ".join( # [ # "{}: {:.4f}".format(index2action[j], q[j]) # for j in range(num_action) # ] # ) # for index2action, q, num_action in zip( # self.index2actions, qs, self.num_actions # ) # ] # ) # print(out_str) inds = [np.argmax(q) for q in qs] else: inds = [ np.random.randint(num_action) for num_action in self.num_actions ] action = [] for j, ind in enumerate(inds): action.extend(self.index2actions[j][ind]) self.epsilon_obj.step() self.eps_step += 1 action = np.asarray(action) return action def save(self, model_dir): model_dir = pathlib.Path(model_dir) torch.save(self.online_q_network.state_dict(), model_dir / "online.pth") torch.save(self.target_q_network.state_dict(), model_dir / "target.pth") def load(self, model_dir, cpu=False): model_dir = pathlib.Path(model_dir) print("loading from :", model_dir) 
map_location = None if cpu: map_location = torch.device("cpu") self.online_q_network.load_state_dict( torch.load(model_dir / "online.pth", map_location=map_location)) self.target_q_network.load_state_dict( torch.load(model_dir / "target.pth", map_location=map_location)) print("Model loaded") def step(self, state, action, reward, next_state, done, others=None): # dont treat timeout as done equal to True max_steps_reached = state["events"].reached_max_episode_steps if max_steps_reached: done = False if self.action_space_type == "continuous": action = to_2d_action(action) _action = ([[e] for e in action] if not self.merge_action_spaces else [action.tolist()]) action_index = np.asarray([ action2index[str(e)] for action2index, e in zip(self.action2indexs, _action) ]) else: action_index = self.lane_actions.index(action) action = action_index self.replay.add( state=state, action=action_index, reward=reward, next_state=next_state, done=done, others=others, social_capacity=self.social_capacity, observation_num_lookahead=self.observation_num_lookahead, social_vehicle_config=self.social_vehicle_config, prev_action=self.prev_action, ) if (self.step_count % self.train_step == 0 and len(self.replay) >= self.batch_size and (self.warmup is None or len(self.replay) >= self.warmup)): out = self.learn() self.update_count += 1 else: out = {} if self.target_update > 1 and self.step_count % self.target_update == 0: self.update_target_network() elif self.target_update < 1.0: self.soft_update(self.target_q_network, self.online_q_network, self.target_update) self.step_count += 1 self.prev_action = action return out def learn(self): states, actions, rewards, next_states, dones, others = self.replay.sample( device=self.device) if not self.merge_action_spaces: actions = torch.chunk(actions, len(self.num_actions), -1) else: actions = [actions] self.target_q_network.eval() with torch.no_grad(): qs_next_target = self.target_q_network(next_states) if self.use_ddqn: self.online_q_network.eval() with torch.no_grad(): qs_next_online = self.online_q_network(next_states) next_actions = [ torch.argmax(q_next_online, dim=1, keepdim=True) for q_next_online in qs_next_online ] else: next_actions = [ torch.argmax(q_next_target, dim=1, keepdim=True) for q_next_target in qs_next_target ] qs_next_target = [ torch.gather(q_next_target, 1, next_action) for q_next_target, next_action in zip(qs_next_target, next_actions) ] self.online_q_network.train() qs, aux_losses = self.online_q_network(states, training=True) qs = [ torch.gather(q, 1, action.long()) for q, action in zip(qs, actions) ] qs_target_value = [ rewards + self.gamma * (1 - dones) * q_next_target for q_next_target in qs_next_target ] td_loss = [ self.loss_func(q, q_target_value).mean() for q, q_target_value in zip(qs, qs_target_value) ] mean_td_loss = sum(td_loss) / len(td_loss) loss = mean_td_loss + sum( [e["value"] * e["weight"] for e in aux_losses.values()]) self.optimizers.zero_grad() loss.backward() self.optimizers.step() out = {} out.update({ "loss/td{}".format(j): { "type": "scalar", "data": td_loss[j].data.cpu().numpy(), "freq": 10, } for j in range(len(td_loss)) }) out.update({ "loss/{}".format(k): { "type": "scalar", "data": v["value"], # .detach().cpu().numpy(), "freq": 10, } for k, v in aux_losses.items() }) out.update({"loss/all": {"type": "scalar", "data": loss, "freq": 10}}) self.num_updates += 1 return out
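
# A minimal sketch (not part of this codebase) of the double-DQN target that
# DQNPolicy.learn() above builds for each action head when `use_ddqn` is set:
# the online network chooses the next action, the target network evaluates it.
# `online_q_next` and `target_q_next` are assumed to be [batch, num_actions]
# tensors of Q-values for the next states.
import torch


def double_dqn_target(rewards, dones, online_q_next, target_q_next, gamma=0.99):
    with torch.no_grad():
        next_actions = torch.argmax(online_q_next, dim=1, keepdim=True)  # action choice
        q_next = torch.gather(target_q_next, 1, next_actions)  # value of that choice
        return rewards + gamma * (1.0 - dones) * q_next
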
class DQNPolicy(Agent): lane_actions = [ "keep_lane", "slow_down", "change_lane_left", "change_lane_right" ] def __init__( self, policy_params=None, checkpoint_dir=None, ): self.policy_params = policy_params self.lr = float(policy_params["lr"]) self.seed = int(policy_params["seed"]) self.train_step = int(policy_params["train_step"]) self.target_update = float(policy_params["target_update"]) self.warmup = int(policy_params["warmup"]) self.gamma = float(policy_params["gamma"]) self.batch_size = int(policy_params["batch_size"]) self.use_ddqn = policy_params["use_ddqn"] self.sticky_actions = int(policy_params["sticky_actions"]) self.epsilon_obj = EpsilonExplore(1.0, 0.05, 100000) self.step_count = 0 self.update_count = 0 self.num_updates = 0 self.current_sticky = 0 self.current_iteration = 0 self.action_type = adapters.type_from_string( policy_params["action_type"]) self.observation_type = adapters.type_from_string( policy_params["observation_type"]) self.reward_type = adapters.type_from_string( policy_params["reward_type"]) if self.action_type == adapters.AdapterType.DefaultActionContinuous: discrete_action_spaces = [ np.asarray([-0.25, 0.0, 0.5, 0.75, 1.0]), np.asarray([ -1.0, -0.75, -0.5, -0.25, -0.1, 0.0, 0.1, 0.25, 0.5, 0.75, 1.0 ]), ] self.index2actions = [ merge_discrete_action_spaces([discrete_action_space])[0] for discrete_action_space in discrete_action_spaces ] self.action2indexs = [ merge_discrete_action_spaces([discrete_action_space])[1] for discrete_action_space in discrete_action_spaces ] self.merge_action_spaces = 0 self.num_actions = [ len(discrete_action_space) for discrete_action_space in discrete_action_spaces ] self.action_size = 2 self.to_real_action = to_3d_action elif self.action_type == adapters.AdapterType.DefaultActionDiscrete: discrete_action_spaces = [[0], [1], [2], [3]] index_to_actions = [ discrete_action_space.tolist() if not isinstance(discrete_action_space, list) else discrete_action_space for discrete_action_space in discrete_action_spaces ] action_to_indexs = { str(discrete_action): index for discrete_action, index in zip( index_to_actions, np.arange(len(index_to_actions)).astype(np.int)) } self.index2actions = [index_to_actions] self.action2indexs = [action_to_indexs] self.merge_action_spaces = -1 self.num_actions = [len(index_to_actions)] self.action_size = 1 self.to_real_action = lambda action: self.lane_actions[action[0]] else: raise Exception( f"DQN baseline does not support the '{self.action_type}' action type." ) if self.observation_type == adapters.AdapterType.DefaultObservationVector: observation_space = adapters.space_from_type(self.observation_type) low_dim_states_size = observation_space["low_dim_states"].shape[0] social_capacity = observation_space["social_vehicles"].shape[0] num_social_features = observation_space["social_vehicles"].shape[1] # Get information to build the encoder. 
encoder_key = policy_params["social_vehicles"]["encoder_key"] social_policy_hidden_units = int( policy_params["social_vehicles"].get( "social_policy_hidden_units", 0)) social_policy_init_std = int(policy_params["social_vehicles"].get( "social_policy_init_std", 0)) social_vehicle_config = get_social_vehicle_configs( encoder_key=encoder_key, num_social_features=num_social_features, social_capacity=social_capacity, seed=self.seed, social_policy_hidden_units=social_policy_hidden_units, social_policy_init_std=social_policy_init_std, ) social_vehicle_encoder = social_vehicle_config["encoder"] social_feature_encoder_class = social_vehicle_encoder[ "social_feature_encoder_class"] social_feature_encoder_params = social_vehicle_encoder[ "social_feature_encoder_params"] # Calculate the state size based on the number of features (ego + social). state_size = low_dim_states_size if social_feature_encoder_class: state_size += social_feature_encoder_class( **social_feature_encoder_params).output_dim else: state_size += social_capacity * num_social_features # Add the action size to account for the previous action. state_size += self.action_size network_class = DQNWithSocialEncoder network_params = { "num_actions": self.num_actions, "state_size": state_size, "social_feature_encoder_class": social_feature_encoder_class, "social_feature_encoder_params": social_feature_encoder_params, } elif self.observation_type == adapters.AdapterType.DefaultObservationImage: observation_space = adapters.space_from_type(self.observation_type) stack_size = observation_space.shape[0] image_shape = (observation_space.shape[1], observation_space.shape[2]) network_class = DQNCNN network_params = { "n_in_channels": stack_size, "image_dim": image_shape, "num_actions": self.num_actions, } else: raise Exception( f"DQN baseline does not support the '{self.observation_type}' " f"observation type.") self.prev_action = np.zeros(self.action_size) self.checkpoint_dir = checkpoint_dir torch.manual_seed(self.seed) self.device_name = "cuda:0" if torch.cuda.is_available() else "cpu" self.device = torch.device(self.device_name) self.online_q_network = network_class(**network_params).to(self.device) self.target_q_network = network_class(**network_params).to(self.device) self.update_target_network() self.optimizers = torch.optim.Adam( params=self.online_q_network.parameters(), lr=self.lr) self.loss_func = nn.MSELoss(reduction="none") self.replay = ReplayBuffer( buffer_size=int(policy_params["replay_buffer"]["buffer_size"]), batch_size=int(policy_params["replay_buffer"]["batch_size"]), observation_type=self.observation_type, device_name=self.device_name, ) self.reset() if self.checkpoint_dir: self.load(self.checkpoint_dir) def lane_action_to_index(self, state): state = state.copy() if (len(state["action"]) == 3 and (state["action"] == np.asarray( [0, 0, 0])).all()): # initial action state["action"] = np.asarray([0]) else: state["action"] = self.lane_actions.index(state["action"]) return state def reset(self): self.eps_throttles = [] self.eps_steers = [] self.eps_step = 0 self.current_sticky = 0 def soft_update(self, target, src, tau): for target_param, param in zip(target.parameters(), src.parameters()): target_param.detach_() target_param.copy_(target_param * (1.0 - tau) + param * tau) def update_target_network(self): self.target_q_network.load_state_dict( self.online_q_network.state_dict().copy()) def act(self, *args, **kwargs): if self.current_sticky == 0: self.action = self._act(*args, **kwargs) self.current_sticky = (self.current_sticky + 1) 
% self.sticky_actions self.current_iteration += 1 return self.to_real_action(self.action) def _act(self, state, explore=True): epsilon = self.epsilon_obj.get_epsilon() if not explore or np.random.rand() > epsilon: state = copy.deepcopy(state) if self.observation_type == adapters.AdapterType.DefaultObservationVector: # Default vector observation type. state["low_dim_states"] = np.float32( np.append(state["low_dim_states"], self.prev_action)) state["social_vehicles"] = (torch.from_numpy( state["social_vehicles"]).unsqueeze(0).to(self.device)) state["low_dim_states"] = (torch.from_numpy( state["low_dim_states"]).unsqueeze(0).to(self.device)) else: # Default image observation type. state = torch.from_numpy(state).unsqueeze(0).to(self.device) self.online_q_network.eval() with torch.no_grad(): qs = self.online_q_network(state) qs = [q.data.cpu().numpy().flatten() for q in qs] # out_str = " || ".join( # [ # " ".join( # [ # "{}: {:.4f}".format(index2action[j], q[j]) # for j in range(num_action) # ] # ) # for index2action, q, num_action in zip( # self.index2actions, qs, self.num_actions # ) # ] # ) # print(out_str) inds = [np.argmax(q) for q in qs] else: inds = [ np.random.randint(num_action) for num_action in self.num_actions ] action = [] for j, ind in enumerate(inds): action.extend(self.index2actions[j][ind]) self.epsilon_obj.step() self.eps_step += 1 action = np.asarray(action) return action def save(self, model_dir): model_dir = pathlib.Path(model_dir) torch.save(self.online_q_network.state_dict(), model_dir / "online.pth") torch.save(self.target_q_network.state_dict(), model_dir / "target.pth") def load(self, model_dir, cpu=False): model_dir = pathlib.Path(model_dir) print("loading from :", model_dir) map_location = None if cpu: map_location = torch.device("cpu") self.online_q_network.load_state_dict( torch.load(model_dir / "online.pth", map_location=map_location)) self.target_q_network.load_state_dict( torch.load(model_dir / "target.pth", map_location=map_location)) print("Model loaded") def step(self, state, action, reward, next_state, done, info, others=None): # dont treat timeout as done equal to True max_steps_reached = info["logs"]["events"].reached_max_episode_steps if max_steps_reached: done = False if self.action_type == adapters.AdapterType.DefaultActionContinuous: action = to_2d_action(action) _action = ([[e] for e in action] if not self.merge_action_spaces else [action.tolist()]) action_index = np.asarray([ action2index[str(e)] for action2index, e in zip(self.action2indexs, _action) ]) else: action_index = self.lane_actions.index(action) action = action_index self.replay.add( state=state, action=action_index, reward=reward, next_state=next_state, done=done, others=others, prev_action=self.prev_action, ) if (self.step_count % self.train_step == 0 and len(self.replay) >= self.batch_size and (self.warmup is None or len(self.replay) >= self.warmup)): out = self.learn() self.update_count += 1 else: out = {} if self.target_update > 1 and self.step_count % self.target_update == 0: self.update_target_network() elif self.target_update < 1.0: self.soft_update(self.target_q_network, self.online_q_network, self.target_update) self.step_count += 1 self.prev_action = action return out def learn(self): states, actions, rewards, next_states, dones, others = self.replay.sample( device=self.device) if not self.merge_action_spaces: actions = torch.chunk(actions, len(self.num_actions), -1) else: actions = [actions] self.target_q_network.eval() with torch.no_grad(): qs_next_target = 
self.target_q_network(next_states) if self.use_ddqn: self.online_q_network.eval() with torch.no_grad(): qs_next_online = self.online_q_network(next_states) next_actions = [ torch.argmax(q_next_online, dim=1, keepdim=True) for q_next_online in qs_next_online ] else: next_actions = [ torch.argmax(q_next_target, dim=1, keepdim=True) for q_next_target in qs_next_target ] qs_next_target = [ torch.gather(q_next_target, 1, next_action) for q_next_target, next_action in zip(qs_next_target, next_actions) ] self.online_q_network.train() qs, aux_losses = self.online_q_network(states, training=True) qs = [ torch.gather(q, 1, action.long()) for q, action in zip(qs, actions) ] qs_target_value = [ rewards + self.gamma * (1 - dones) * q_next_target for q_next_target in qs_next_target ] td_loss = [ self.loss_func(q, q_target_value).mean() for q, q_target_value in zip(qs, qs_target_value) ] mean_td_loss = sum(td_loss) / len(td_loss) loss = mean_td_loss + sum( [e["value"] * e["weight"] for e in aux_losses.values()]) self.optimizers.zero_grad() loss.backward() self.optimizers.step() out = {} out.update({ "loss/td{}".format(j): { "type": "scalar", "data": td_loss[j].data.cpu().numpy(), "freq": 10, } for j in range(len(td_loss)) }) out.update({ "loss/{}".format(k): { "type": "scalar", "data": v["value"], # .detach().cpu().numpy(), "freq": 10, } for k, v in aux_losses.items() }) out.update({"loss/all": {"type": "scalar", "data": loss, "freq": 10}}) self.num_updates += 1 return out
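
# A minimal sketch (not part of this codebase) of how DQNPolicy.step() above
# interprets `target_update`: a value greater than 1 is a hard-copy period in
# environment steps, while a value below 1 is the Polyak coefficient for a
# soft update applied every step. The networks here are assumed to be any
# torch.nn.Module pair with matching parameters.
import torch


def maybe_update_target(step_count, target_update, online_net, target_net):
    if target_update > 1 and step_count % int(target_update) == 0:
        # Periodic hard copy of the online weights into the target network.
        target_net.load_state_dict(online_net.state_dict())
    elif target_update < 1.0:
        # Soft update: target <- (1 - tau) * target + tau * online.
        with torch.no_grad():
            for t_param, o_param in zip(
                target_net.parameters(), online_net.parameters()
            ):
                t_param.copy_(t_param * (1.0 - target_update) + o_param * target_update)
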
def test_image_replay_buffer(self):
    TRANSITIONS = 1024  # The number of transitions to save in the replay buffer.
    STACK_SIZE = 4  # The stack size of the images.
    ACTION_SIZE = 3  # The size of each action.
    IMAGE_WIDTH = 64  # The width of each image.
    IMAGE_HEIGHT = 64  # The height of each image.
    BUFFER_SIZE = 1024  # The size of the replay buffer.
    BATCH_SIZE = 128  # Batch size of each sample from the replay buffer.
    NUM_SAMPLES = 10  # Number of times to sample from the replay buffer.

    replay_buffer = ReplayBuffer(
        buffer_size=BUFFER_SIZE,
        batch_size=BATCH_SIZE,
        observation_type=adapters.AdapterType.DefaultObservationImage,
        device_name="cpu",
    )
    (
        states,
        next_states,
        previous_actions,
        actions,
        rewards,
        dones,
    ) = generate_image_transitions(
        TRANSITIONS, STACK_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH, ACTION_SIZE
    )

    for state, next_state, action, previous_action, reward, done in zip(
        states, next_states, actions, previous_actions, rewards, dones
    ):
        replay_buffer.add(
            state=state,
            next_state=next_state,
            action=action,
            prev_action=previous_action,
            reward=reward,
            done=done,
        )

    for _ in range(NUM_SAMPLES):
        sample = replay_buffer.sample()
        for state, action, reward, next_state, done, _ in zip(*sample):
            state = state.numpy()
            action = action.numpy()
            reward = reward.numpy()[0]
            next_state = next_state.numpy()
            done = True if done.numpy()[0] else False

            index_of_state = None
            for index, original_state in enumerate(states):
                if np.array_equal(original_state, state):
                    index_of_state = index
                    break

            self.assertIn(state, states)
            self.assertIn(next_state, next_states)
            self.assertIn(action, actions)
            self.assertIn(reward, rewards)
            self.assertIn(done, dones)
            self.assertTrue(np.array_equal(state, states[index_of_state]))
            self.assertTrue(np.array_equal(next_state, next_states[index_of_state]))
            self.assertTrue(np.array_equal(action, actions[index_of_state]))
            self.assertEqual(reward, rewards[index_of_state])
            self.assertEqual(done, dones[index_of_state])
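
# A minimal FIFO replay buffer (not the project's ReplayBuffer) illustrating the
# store-then-sample round trip that test_image_replay_buffer verifies above:
# every sampled transition must be one of the transitions that was added.
import random
from collections import deque

import numpy as np


class MinimalReplayBuffer:
    def __init__(self, buffer_size, batch_size):
        self._storage = deque(maxlen=buffer_size)  # oldest transitions evicted first
        self._batch_size = batch_size

    def __len__(self):
        return len(self._storage)

    def add(self, state, action, reward, next_state, done):
        self._storage.append((state, action, reward, next_state, done))

    def sample(self):
        batch = random.sample(list(self._storage), self._batch_size)
        states, actions, rewards, next_states, dones = map(np.asarray, zip(*batch))
        return states, actions, rewards, next_states, dones
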
class SACPolicy(Agent): def __init__( self, policy_params=None, checkpoint_dir=None, ): # print("LOADING THE PARAMS", policy_params, checkpoint_dir) self.policy_params = policy_params self.gamma = float(policy_params["gamma"]) self.critic_lr = float(policy_params["critic_lr"]) self.actor_lr = float(policy_params["actor_lr"]) self.critic_update_rate = int(policy_params["critic_update_rate"]) self.policy_update_rate = int(policy_params["policy_update_rate"]) self.warmup = int(policy_params["warmup"]) self.seed = int(policy_params["seed"]) self.batch_size = int(policy_params["batch_size"]) self.hidden_units = int(policy_params["hidden_units"]) self.tau = float(policy_params["tau"]) self.initial_alpha = float(policy_params["initial_alpha"]) self.logging_freq = int(policy_params["logging_freq"]) self.action_size = 2 self.prev_action = np.zeros(self.action_size) self.action_type = adapters.type_from_string( policy_params["action_type"]) self.observation_type = adapters.type_from_string( policy_params["observation_type"]) self.reward_type = adapters.type_from_string( policy_params["reward_type"]) if self.action_type != adapters.AdapterType.DefaultActionContinuous: raise Exception( f"SAC baseline only supports the " f"{adapters.AdapterType.DefaultActionContinuous} action type.") if self.observation_type != adapters.AdapterType.DefaultObservationVector: raise Exception( f"SAC baseline only supports the " f"{adapters.AdapterType.DefaultObservationVector} observation type." ) self.observation_space = adapters.space_from_type( self.observation_type) self.low_dim_states_size = self.observation_space[ "low_dim_states"].shape[0] self.social_capacity = self.observation_space["social_vehicles"].shape[ 0] self.num_social_features = self.observation_space[ "social_vehicles"].shape[1] self.encoder_key = policy_params["social_vehicles"]["encoder_key"] self.social_policy_hidden_units = int( policy_params["social_vehicles"].get("social_policy_hidden_units", 0)) self.social_policy_init_std = int(policy_params["social_vehicles"].get( "social_policy_init_std", 0)) self.social_vehicle_config = get_social_vehicle_configs( encoder_key=self.encoder_key, num_social_features=self.num_social_features, social_capacity=self.social_capacity, seed=self.seed, social_policy_hidden_units=self.social_policy_hidden_units, social_policy_init_std=self.social_policy_init_std, ) self.social_vehicle_encoder = self.social_vehicle_config["encoder"] self.social_feature_encoder_class = self.social_vehicle_encoder[ "social_feature_encoder_class"] self.social_feature_encoder_params = self.social_vehicle_encoder[ "social_feature_encoder_params"] # others self.checkpoint_dir = checkpoint_dir self.device_name = "cuda:0" if torch.cuda.is_available() else "cpu" self.device = torch.device(self.device_name) self.save_codes = (policy_params["save_codes"] if "save_codes" in policy_params else None) self.memory = ReplayBuffer( buffer_size=int(policy_params["replay_buffer"]["buffer_size"]), batch_size=int(policy_params["replay_buffer"]["batch_size"]), observation_type=self.observation_type, device_name=self.device_name, ) self.current_iteration = 0 self.steps = 0 self.init_networks() if checkpoint_dir: self.load(checkpoint_dir) @property def state_size(self): # Adjusting state_size based on number of features (ego+social) size = self.low_dim_states_size if self.social_feature_encoder_class: size += self.social_feature_encoder_class( **self.social_feature_encoder_params).output_dim else: size += self.social_capacity * self.num_social_features # adding 
the previous action size += self.action_size return size def init_networks(self): self.sac_net = SACNetwork( action_size=self.action_size, state_size=self.state_size, hidden_units=self.hidden_units, seed=self.seed, initial_alpha=self.initial_alpha, social_feature_encoder_class=self.social_feature_encoder_class, social_feature_encoder_params=self.social_feature_encoder_params, ).to(self.device_name) self.actor_optimizer = torch.optim.Adam( self.sac_net.actor.parameters(), lr=self.actor_lr) self.critic_optimizer = torch.optim.Adam( self.sac_net.critic.parameters(), lr=self.critic_lr) self.log_alpha_optimizer = torch.optim.Adam([self.sac_net.log_alpha], lr=self.critic_lr) def act(self, state, explore=True): state = copy.deepcopy(state) state["low_dim_states"] = np.float32( np.append(state["low_dim_states"], self.prev_action)) state["social_vehicles"] = (torch.from_numpy( state["social_vehicles"]).unsqueeze(0).to(self.device)) state["low_dim_states"] = (torch.from_numpy( state["low_dim_states"]).unsqueeze(0).to(self.device)) action, _, mean = self.sac_net.sample(state) if explore: # training mode action = torch.squeeze(action, 0) action = action.detach().cpu().numpy() else: # testing mode mean = torch.squeeze(mean, 0) action = mean.detach().cpu().numpy() return to_3d_action(action) def step(self, state, action, reward, next_state, done, info): # dont treat timeout as done equal to True max_steps_reached = info["logs"]["events"].reached_max_episode_steps if max_steps_reached: done = False action = to_2d_action(action) self.memory.add( state=state, action=action, reward=reward, next_state=next_state, done=float(done), prev_action=self.prev_action, ) self.steps += 1 output = {} if self.steps > max(self.warmup, self.batch_size): states, actions, rewards, next_states, dones, others = self.memory.sample( device=self.device_name) if self.steps % self.critic_update_rate == 0: critic_loss = self.update_critic(states, actions, rewards, next_states, dones) output["loss/critic_loss"] = { "type": "scalar", "data": critic_loss.item(), "freq": 2, } if self.steps % self.policy_update_rate == 0: actor_loss, temp_loss = self.update_actor_temp( states, actions, rewards, next_states, dones) output["loss/actor_loss"] = { "type": "scalar", "data": actor_loss.item(), "freq": self.logging_freq, } output["loss/temp_loss"] = { "type": "scalar", "data": temp_loss.item(), "freq": self.logging_freq, } output["others/alpha"] = { "type": "scalar", "data": self.sac_net.alpha.item(), "freq": self.logging_freq, } self.current_iteration += 1 self.target_soft_update(self.sac_net.critic, self.sac_net.target, self.tau) self.prev_action = action if not done else np.zeros(self.action_size) return output def update_critic(self, states, actions, rewards, next_states, dones): q1_current, q2_current, aux_losses = self.sac_net.critic(states, actions, training=True) with torch.no_grad(): next_actions, log_probs, _ = self.sac_net.sample(next_states) q1_next, q2_next = self.sac_net.target(next_states, next_actions) v_next = (torch.min(q1_next, q2_next) - self.sac_net.alpha.detach() * log_probs) q_target = (rewards + ((1 - dones) * self.gamma * v_next)).detach() critic_loss = F.mse_loss(q1_current, q_target) + F.mse_loss( q2_current, q_target) aux_losses = compute_sum_aux_losses(aux_losses) overall_loss = critic_loss + aux_losses self.critic_optimizer.zero_grad() overall_loss.backward() self.critic_optimizer.step() return critic_loss def update_actor_temp(self, states, actions, rewards, next_states, dones): for p in 
self.sac_net.target.parameters(): p.requires_grad = False for p in self.sac_net.critic.parameters(): p.requires_grad = False # update actor: actions, log_probs, aux_losses = self.sac_net.sample(states, training=True) q1, q2 = self.sac_net.critic(states, actions) q_old = torch.min(q1, q2) actor_loss = (self.sac_net.alpha.detach() * log_probs - q_old).mean() aux_losses = compute_sum_aux_losses(aux_losses) overall_loss = actor_loss + aux_losses self.actor_optimizer.zero_grad() overall_loss.backward() self.actor_optimizer.step() # update temp: temp_loss = (self.sac_net.log_alpha.exp() * (-log_probs.detach().mean() + self.action_size).detach()) self.log_alpha_optimizer.zero_grad() temp_loss.backward() self.log_alpha_optimizer.step() self.sac_net.alpha.data = self.sac_net.log_alpha.exp().detach() for p in self.sac_net.target.parameters(): p.requires_grad = True for p in self.sac_net.critic.parameters(): p.requires_grad = True return actor_loss, temp_loss def target_soft_update(self, critic, target_critic, tau): with torch.no_grad(): for critic_param, target_critic_param in zip( critic.parameters(), target_critic.parameters()): target_critic_param.data = ( tau * critic_param.data + (1 - tau) * target_critic_param.data) def load(self, model_dir): model_dir = pathlib.Path(model_dir) map_location = None if self.device and self.device.type == "cpu": map_location = "cpu" self.sac_net.actor.load_state_dict( torch.load(model_dir / "actor.pth", map_location=map_location)) self.sac_net.target.load_state_dict( torch.load(model_dir / "target.pth", map_location=map_location)) self.sac_net.critic.load_state_dict( torch.load(model_dir / "critic.pth", map_location=map_location)) print("<<<<<<< MODEL LOADED >>>>>>>>>", model_dir) def save(self, model_dir): model_dir = pathlib.Path(model_dir) # with open(model_dir / "params.yaml", "w") as file: # yaml.dump(policy_params, file) torch.save(self.sac_net.actor.state_dict(), model_dir / "actor.pth") torch.save(self.sac_net.target.state_dict(), model_dir / "target.pth") torch.save(self.sac_net.critic.state_dict(), model_dir / "critic.pth") print("<<<<<<< MODEL SAVED >>>>>>>>>", model_dir) def reset(self): pass
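
# A minimal sketch (not part of this codebase) of the two SAC-specific pieces
# used by SACPolicy above, written with plain tensors: the soft value target
# from update_critic() and the temperature loss from update_actor_temp().
# Shapes and the `target_entropy` argument are illustrative; the policy above
# passes its action size for that term.
import torch


def sac_critic_target(rewards, dones, q1_next, q2_next, log_probs, alpha, gamma=0.99):
    """r + gamma * (1 - done) * (min(Q1', Q2') - alpha * log pi(a'|s'))."""
    v_next = torch.min(q1_next, q2_next) - alpha * log_probs
    return (rewards + (1.0 - dones) * gamma * v_next).detach()


def sac_temperature_loss(log_alpha, log_probs, target_entropy):
    """Loss that adapts alpha = exp(log_alpha) toward the entropy target."""
    return log_alpha.exp() * (-log_probs.detach().mean() + target_entropy).detach()
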
class TD3Policy(Agent): def __init__( self, policy_params=None, checkpoint_dir=None, ): self.policy_params = policy_params self.action_size = 2 self.action_range = np.asarray([[-1.0, 1.0], [-1.0, 1.0]], dtype=np.float32) self.actor_lr = float(policy_params["actor_lr"]) self.critic_lr = float(policy_params["critic_lr"]) self.critic_wd = float(policy_params["critic_wd"]) self.actor_wd = float(policy_params["actor_wd"]) self.noise_clip = float(policy_params["noise_clip"]) self.policy_noise = float(policy_params["policy_noise"]) self.update_rate = int(policy_params["update_rate"]) self.policy_delay = int(policy_params["policy_delay"]) self.warmup = int(policy_params["warmup"]) self.critic_tau = float(policy_params["critic_tau"]) self.actor_tau = float(policy_params["actor_tau"]) self.gamma = float(policy_params["gamma"]) self.batch_size = int(policy_params["batch_size"]) self.sigma = float(policy_params["sigma"]) self.theta = float(policy_params["theta"]) self.dt = float(policy_params["dt"]) self.action_low = torch.tensor([[each[0] for each in self.action_range]]) self.action_high = torch.tensor([[each[1] for each in self.action_range]]) self.seed = int(policy_params["seed"]) self.prev_action = np.zeros(self.action_size) self.action_type = adapters.type_from_string(policy_params["action_type"]) self.observation_type = adapters.type_from_string( policy_params["observation_type"] ) self.reward_type = adapters.type_from_string(policy_params["reward_type"]) if self.action_type != adapters.AdapterType.DefaultActionContinuous: raise Exception( f"TD3 baseline only supports the " f"{adapters.AdapterType.DefaultActionContinuous} action type." ) if self.observation_type == adapters.AdapterType.DefaultObservationVector: observation_space = adapters.space_from_type(self.observation_type) low_dim_states_size = observation_space["low_dim_states"].shape[0] social_capacity = observation_space["social_vehicles"].shape[0] num_social_features = observation_space["social_vehicles"].shape[1] # Get information to build the encoder. encoder_key = policy_params["social_vehicles"]["encoder_key"] social_policy_hidden_units = int( policy_params["social_vehicles"].get("social_policy_hidden_units", 0) ) social_policy_init_std = int( policy_params["social_vehicles"].get("social_policy_init_std", 0) ) social_vehicle_config = get_social_vehicle_configs( encoder_key=encoder_key, num_social_features=num_social_features, social_capacity=social_capacity, seed=self.seed, social_policy_hidden_units=social_policy_hidden_units, social_policy_init_std=social_policy_init_std, ) social_vehicle_encoder = social_vehicle_config["encoder"] social_feature_encoder_class = social_vehicle_encoder[ "social_feature_encoder_class" ] social_feature_encoder_params = social_vehicle_encoder[ "social_feature_encoder_params" ] # Calculate the state size based on the number of features (ego + social). state_size = low_dim_states_size if social_feature_encoder_class: state_size += social_feature_encoder_class( **social_feature_encoder_params ).output_dim else: state_size += social_capacity * num_social_features # Add the action size to account for the previous action. 
state_size += self.action_size actor_network_class = FCActorNetwork critic_network_class = FCCrtiicNetwork network_params = { "state_space": state_size, "action_space": self.action_size, "seed": self.seed, "social_feature_encoder": social_feature_encoder_class( **social_feature_encoder_params ) if social_feature_encoder_class else None, } elif self.observation_type == adapters.AdapterType.DefaultObservationImage: observation_space = adapters.space_from_type(self.observation_type) stack_size = observation_space.shape[0] image_shape = (observation_space.shape[1], observation_space.shape[2]) actor_network_class = CNNActorNetwork critic_network_class = CNNCriticNetwork network_params = { "input_channels": stack_size, "input_dimension": image_shape, "action_size": self.action_size, "seed": self.seed, } else: raise Exception( f"TD3 baseline does not support the '{self.observation_type}' " f"observation type." ) # others self.checkpoint_dir = checkpoint_dir self.device_name = "cuda:0" if torch.cuda.is_available() else "cpu" self.device = torch.device(self.device_name) self.save_codes = ( policy_params["save_codes"] if "save_codes" in policy_params else None ) self.memory = ReplayBuffer( buffer_size=int(policy_params["replay_buffer"]["buffer_size"]), batch_size=int(policy_params["replay_buffer"]["batch_size"]), observation_type=self.observation_type, device_name=self.device_name, ) self.num_actor_updates = 0 self.current_iteration = 0 self.step_count = 0 self.init_networks(actor_network_class, critic_network_class, network_params) if checkpoint_dir: self.load(checkpoint_dir) def init_networks(self, actor_network_class, critic_network_class, network_params): self.noise = [ OrnsteinUhlenbeckProcess( size=(1,), theta=0.01, std=LinearSchedule(0.25), mu=0.0, x0=0.0, dt=1.0 ), # throttle OrnsteinUhlenbeckProcess( size=(1,), theta=0.1, std=LinearSchedule(0.05), mu=0.0, x0=0.0, dt=1.0 ), # steering ] self.actor = actor_network_class(**network_params).to(self.device) self.actor_target = actor_network_class(**network_params).to(self.device) self.actor_target.load_state_dict(self.actor.state_dict()) self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=self.actor_lr) self.critic_1 = critic_network_class(**network_params).to(self.device) self.critic_1_target = critic_network_class(**network_params).to(self.device) self.critic_1_target.load_state_dict(self.critic_1.state_dict()) self.critic_1_optimizer = optim.Adam( self.critic_1.parameters(), lr=self.critic_lr ) self.critic_2 = critic_network_class(**network_params).to(self.device) self.critic_2_target = critic_network_class(**network_params).to(self.device) self.critic_2_target.load_state_dict(self.critic_2.state_dict()) self.critic_2_optimizer = optim.Adam( self.critic_2.parameters(), lr=self.critic_lr ) def act(self, state, explore=True): state = copy.deepcopy(state) if self.observation_type == adapters.AdapterType.DefaultObservationVector: # Default vector observation type. state["low_dim_states"] = np.float32( np.append(state["low_dim_states"], self.prev_action) ) state["social_vehicles"] = ( torch.from_numpy(state["social_vehicles"]).unsqueeze(0).to(self.device) ) state["low_dim_states"] = ( torch.from_numpy(state["low_dim_states"]).unsqueeze(0).to(self.device) ) else: # Default image observation type. 
state = torch.from_numpy(state).unsqueeze(0).to(self.device) self.actor.eval() action = self.actor(state).cpu().data.numpy().flatten() noise = [self.noise[0].sample(), self.noise[1].sample()] if explore: action[0] += noise[0] action[1] += noise[1] self.actor.train() action_low, action_high = ( self.action_low.data.cpu().numpy(), self.action_high.data.cpu().numpy(), ) action = np.clip(action, action_low, action_high)[0] return to_3d_action(action) def step(self, state, action, reward, next_state, done, info): # dont treat timeout as done equal to True max_steps_reached = info["logs"]["events"].reached_max_episode_steps reset_noise = False if max_steps_reached: done = False reset_noise = True output = {} action = to_2d_action(action) self.memory.add( state=state, action=action, reward=reward, next_state=next_state, done=float(done), prev_action=self.prev_action, ) self.step_count += 1 if reset_noise: self.reset() if ( len(self.memory) > max(self.batch_size, self.warmup) and (self.step_count + 1) % self.update_rate == 0 ): output = self.learn() self.prev_action = action if not done else np.zeros(self.action_size) return output def reset(self): self.noise[0].reset_states() self.noise[1].reset_states() def learn(self): output = {} states, actions, rewards, next_states, dones, others = self.memory.sample( device=self.device ) actions = actions.squeeze(dim=1) next_actions = self.actor_target(next_states) noise = torch.randn_like(next_actions).mul(self.policy_noise) noise = noise.clamp(-self.noise_clip, self.noise_clip) next_actions += noise next_actions = torch.max( torch.min(next_actions, self.action_high.to(self.device)), self.action_low.to(self.device), ) target_Q1 = self.critic_1_target(next_states, next_actions) target_Q2 = self.critic_2_target(next_states, next_actions) target_Q = torch.min(target_Q1, target_Q2) target_Q = (rewards + ((1 - dones) * self.gamma * target_Q)).detach() # Optimize Critic 1: current_Q1, aux_losses_Q1 = self.critic_1(states, actions, training=True) loss_Q1 = F.mse_loss(current_Q1, target_Q) + compute_sum_aux_losses( aux_losses_Q1 ) self.critic_1_optimizer.zero_grad() loss_Q1.backward() self.critic_1_optimizer.step() # Optimize Critic 2: current_Q2, aux_losses_Q2 = self.critic_2(states, actions, training=True) loss_Q2 = F.mse_loss(current_Q2, target_Q) + compute_sum_aux_losses( aux_losses_Q2 ) self.critic_2_optimizer.zero_grad() loss_Q2.backward() self.critic_2_optimizer.step() # delayed actor updates if (self.step_count + 1) % self.policy_delay == 0: critic_out = self.critic_1(states, self.actor(states), training=True) actor_loss, actor_aux_losses = -critic_out[0], critic_out[1] actor_loss = actor_loss.mean() + compute_sum_aux_losses(actor_aux_losses) self.actor_optimizer.zero_grad() actor_loss.backward() self.actor_optimizer.step() self.soft_update(self.actor_target, self.actor, self.actor_tau) self.num_actor_updates += 1 output = { "loss/critic_1": { "type": "scalar", "data": loss_Q1.data.cpu().numpy(), "freq": 10, }, "loss/actor": { "type": "scalar", "data": actor_loss.data.cpu().numpy(), "freq": 10, }, } self.soft_update(self.critic_1_target, self.critic_1, self.critic_tau) self.soft_update(self.critic_2_target, self.critic_2, self.critic_tau) self.current_iteration += 1 return output def soft_update(self, target, src, tau): for target_param, param in zip(target.parameters(), src.parameters()): target_param.detach_() target_param.copy_(target_param * (1.0 - tau) + param * tau) def load(self, model_dir): model_dir = pathlib.Path(model_dir) map_location = None if 
self.device and self.device.type == "cpu": map_location = "cpu" self.actor.load_state_dict( torch.load(model_dir / "actor.pth", map_location=map_location) ) self.actor_target.load_state_dict( torch.load(model_dir / "actor_target.pth", map_location=map_location) ) self.critic_1.load_state_dict( torch.load(model_dir / "critic_1.pth", map_location=map_location) ) self.critic_1_target.load_state_dict( torch.load(model_dir / "critic_1_target.pth", map_location=map_location) ) self.critic_2.load_state_dict( torch.load(model_dir / "critic_2.pth", map_location=map_location) ) self.critic_2_target.load_state_dict( torch.load(model_dir / "critic_2_target.pth", map_location=map_location) ) def save(self, model_dir): model_dir = pathlib.Path(model_dir) torch.save(self.actor.state_dict(), model_dir / "actor.pth") torch.save( self.actor_target.state_dict(), model_dir / "actor_target.pth", ) torch.save(self.critic_1.state_dict(), model_dir / "critic_1.pth") torch.save( self.critic_1_target.state_dict(), model_dir / "critic_1_target.pth", ) torch.save(self.critic_2.state_dict(), model_dir / "critic_2.pth") torch.save( self.critic_2_target.state_dict(), model_dir / "critic_2_target.pth", )
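
# A minimal Ornstein-Uhlenbeck process (not the project's
# OrnsteinUhlenbeckProcess, which also takes a std schedule) illustrating the
# exploration noise that TD3Policy.init_networks() above attaches to the
# throttle and steering actions. Parameter defaults here mirror the steering
# noise settings used above, with a fixed sigma standing in for the schedule.
import numpy as np


class MinimalOUNoise:
    def __init__(self, size=(1,), theta=0.1, sigma=0.05, mu=0.0, x0=0.0, dt=1.0):
        self.size = size
        self.theta = theta
        self.sigma = sigma
        self.mu = mu
        self.x0 = x0
        self.dt = dt
        self.reset_states()

    def reset_states(self):
        self.x_prev = np.full(self.size, self.x0, dtype=np.float64)

    def sample(self):
        # dx = theta * (mu - x) * dt + sigma * sqrt(dt) * N(0, 1)
        dx = self.theta * (self.mu - self.x_prev) * self.dt + self.sigma * np.sqrt(
            self.dt
        ) * np.random.standard_normal(self.size)
        self.x_prev = self.x_prev + dx
        return self.x_prev
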