Example #1
class TD3Policy(Agent):
    def __init__(
        self,
        policy_params=None,
        checkpoint_dir=None,
    ):
        self.policy_params = policy_params
        self.action_size = int(policy_params["action_size"])
        self.action_range = np.asarray([[-1.0, 1.0], [-1.0, 1.0]],
                                       dtype=np.float32)
        self.actor_lr = float(policy_params["actor_lr"])
        self.critic_lr = float(policy_params["critic_lr"])
        self.critic_wd = float(policy_params["critic_wd"])
        self.actor_wd = float(policy_params["actor_wd"])
        self.noise_clip = float(policy_params["noise_clip"])
        self.policy_noise = float(policy_params["policy_noise"])
        self.update_rate = int(policy_params["update_rate"])
        self.policy_delay = int(policy_params["policy_delay"])
        self.warmup = int(policy_params["warmup"])
        self.critic_tau = float(policy_params["critic_tau"])
        self.actor_tau = float(policy_params["actor_tau"])
        self.gamma = float(policy_params["gamma"])
        self.batch_size = int(policy_params["batch_size"])
        self.sigma = float(policy_params["sigma"])
        self.theta = float(policy_params["theta"])
        self.dt = float(policy_params["dt"])
        self.action_low = torch.tensor(
            [[each[0] for each in self.action_range]])
        self.action_high = torch.tensor(
            [[each[1] for each in self.action_range]])
        self.seed = int(policy_params["seed"])
        self.prev_action = np.zeros(self.action_size)

        # state preprocessing
        self.social_policy_hidden_units = int(
            policy_params["social_vehicles"].get("social_policy_hidden_units",
                                                 0))
        self.social_capacity = int(policy_params["social_vehicles"].get(
            "social_capacity", 0))
        self.observation_num_lookahead = int(
            policy_params.get("observation_num_lookahead", 0))
        self.social_policy_init_std = int(policy_params["social_vehicles"].get(
            "social_policy_init_std", 0))
        self.num_social_features = int(policy_params["social_vehicles"].get(
            "num_social_features", 0))
        self.social_vehicle_config = get_social_vehicle_configs(
            **policy_params["social_vehicles"])

        self.social_vehicle_encoder = self.social_vehicle_config["encoder"]
        self.state_description = BaselineStatePreprocessor.get_state_description(
            policy_params["social_vehicles"],
            policy_params["observation_num_lookahead"],
            self.action_size,
        )

        self.social_feature_encoder_class = self.social_vehicle_encoder[
            "social_feature_encoder_class"]
        self.social_feature_encoder_params = self.social_vehicle_encoder[
            "social_feature_encoder_params"]

        # others
        self.checkpoint_dir = checkpoint_dir
        self.device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(self.device_name)
        self.save_codes = policy_params.get("save_codes")
        self.memory = ReplayBuffer(
            buffer_size=int(policy_params["replay_buffer"]["buffer_size"]),
            batch_size=int(policy_params["replay_buffer"]["batch_size"]),
            device_name=self.device_name,
        )
        self.num_actor_updates = 0
        self.current_iteration = 0
        self.step_count = 0
        self.init_networks()
        if checkpoint_dir:
            self.load(checkpoint_dir)

    @property
    def state_size(self):
        # Adjusting state_size based on number of features (ego+social)
        size = sum(self.state_description["low_dim_states"].values())
        if self.social_feature_encoder_class:
            size += self.social_feature_encoder_class(
                **self.social_feature_encoder_params).output_dim
        else:
            size += self.social_capacity * self.num_social_features
        # adding the previous action
        size += self.action_size
        return size

    def init_networks(self):
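        # Two independent Ornstein-Uhlenbeck processes provide temporally correlated
        # exploration noise for throttle and steering respectively.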
        self.noise = [
            OrnsteinUhlenbeckProcess(size=(1, ),
                                     theta=0.01,
                                     std=LinearSchedule(0.25),
                                     mu=0.0,
                                     x0=0.0,
                                     dt=1.0),  # throttle
            OrnsteinUhlenbeckProcess(size=(1, ),
                                     theta=0.1,
                                     std=LinearSchedule(0.05),
                                     mu=0.0,
                                     x0=0.0,
                                     dt=1.0),  # steering
        ]
        self.actor = ActorNetwork(
            self.state_size,
            self.action_size,
            self.seed,
            social_feature_encoder=self.social_feature_encoder_class(
                **self.social_feature_encoder_params)
            if self.social_feature_encoder_class else None,
        ).to(self.device)
        self.actor_target = ActorNetwork(
            self.state_size,
            self.action_size,
            self.seed,
            social_feature_encoder=self.social_feature_encoder_class(
                **self.social_feature_encoder_params)
            if self.social_feature_encoder_class else None,
        ).to(self.device)
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.actor_optimizer = optim.Adam(self.actor.parameters(),
                                          lr=self.actor_lr)

        self.critic_1 = CriticNetwork(
            self.state_size,
            self.action_size,
            self.seed,
            social_feature_encoder=self.social_feature_encoder_class(
                **self.social_feature_encoder_params)
            if self.social_feature_encoder_class else None,
        ).to(self.device)
        self.critic_1_target = CriticNetwork(
            self.state_size,
            self.action_size,
            self.seed,
            social_feature_encoder=self.social_feature_encoder_class(
                **self.social_feature_encoder_params)
            if self.social_feature_encoder_class else None,
        ).to(self.device)
        self.critic_1_target.load_state_dict(self.critic_1.state_dict())
        self.critic_1_optimizer = optim.Adam(self.critic_1.parameters(),
                                             lr=self.critic_lr)

        self.critic_2 = CriticNetwork(
            self.state_size,
            self.action_size,
            self.seed,
            social_feature_encoder=self.social_feature_encoder_class(
                **self.social_feature_encoder_params)
            if self.social_feature_encoder_class else None,
        ).to(self.device)
        self.critic_2_target = CriticNetwork(
            self.state_size,
            self.action_size,
            self.seed,
            social_feature_encoder=self.social_feature_encoder_class(
                **self.social_feature_encoder_params)
            if self.social_feature_encoder_class else None,
        ).to(self.device)
        self.critic_2_target.load_state_dict(self.critic_2.state_dict())
        self.critic_2_optimizer = optim.Adam(self.critic_2.parameters(),
                                             lr=self.critic_lr)

    def act(self, state, explore=True):
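        # Append the previous action to the low-dimensional ego state before passing
        # the observation to the actor.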
        state = copy.deepcopy(state)
        state["low_dim_states"] = np.float32(
            np.append(state["low_dim_states"], self.prev_action))
        state["social_vehicles"] = (torch.from_numpy(
            state["social_vehicles"]).unsqueeze(0).to(self.device))
        state["low_dim_states"] = (torch.from_numpy(
            state["low_dim_states"]).unsqueeze(0).to(self.device))

        self.actor.eval()

        action = self.actor(state).cpu().data.numpy().flatten()

        noise = [self.noise[0].sample(), self.noise[1].sample()]
        if explore:
            action[0] += noise[0]
            action[1] += noise[1]

        self.actor.train()
        action_low, action_high = (
            self.action_low.data.cpu().numpy(),
            self.action_high.data.cpu().numpy(),
        )
        action = np.clip(action, action_low, action_high)[0]

        return to_3d_action(action)

    def step(self, state, action, reward, next_state, done, info):
        # Don't treat a timeout (reaching the max episode steps) as a terminal state.
        max_steps_reached = info["logs"]["events"].reached_max_episode_steps
        reset_noise = False
        if max_steps_reached:
            done = False
            reset_noise = True

        output = {}
        action = to_2d_action(action)
        self.memory.add(
            state=state,
            action=action,
            reward=reward,
            next_state=next_state,
            done=float(done),
            social_capacity=self.social_capacity,
            observation_num_lookahead=self.observation_num_lookahead,
            social_vehicle_config=self.social_vehicle_config,
            prev_action=self.prev_action,
        )
        self.step_count += 1
        if reset_noise:
            self.reset()
        if (len(self.memory) > max(self.batch_size, self.warmup)
                and (self.step_count + 1) % self.update_rate == 0):
            output = self.learn()

        self.prev_action = action if not done else np.zeros(self.action_size)
        return output

    def reset(self):
        self.noise[0].reset_states()
        self.noise[1].reset_states()

    def learn(self):
        output = {}
        states, actions, rewards, next_states, dones, others = self.memory.sample(
            device=self.device)
        actions = actions.squeeze(dim=1)
        next_actions = self.actor_target(next_states)
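        # Target policy smoothing: perturb the target actions with clipped Gaussian
        # noise before evaluating the target critics.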
        noise = torch.randn_like(next_actions).mul(self.policy_noise)
        noise = noise.clamp(-self.noise_clip, self.noise_clip)
        next_actions += noise
        next_actions = torch.max(
            torch.min(next_actions, self.action_high.to(self.device)),
            self.action_low.to(self.device),
        )

        target_Q1 = self.critic_1_target(next_states, next_actions)
        target_Q2 = self.critic_2_target(next_states, next_actions)
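        # Clipped double-Q learning: take the minimum of the two target critic
        # estimates to reduce overestimation bias in the Bellman target.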
        target_Q = torch.min(target_Q1, target_Q2)
        target_Q = (rewards + ((1 - dones) * self.gamma * target_Q)).detach()

        # Optimize Critic 1:
        current_Q1, aux_losses_Q1 = self.critic_1(states,
                                                  actions,
                                                  training=True)
        loss_Q1 = F.mse_loss(current_Q1,
                             target_Q) + compute_sum_aux_losses(aux_losses_Q1)
        self.critic_1_optimizer.zero_grad()
        loss_Q1.backward()
        self.critic_1_optimizer.step()

        # Optimize Critic 2:
        current_Q2, aux_losses_Q2 = self.critic_2(states,
                                                  actions,
                                                  training=True)
        loss_Q2 = F.mse_loss(current_Q2,
                             target_Q) + compute_sum_aux_losses(aux_losses_Q2)
        self.critic_2_optimizer.zero_grad()
        loss_Q2.backward()
        self.critic_2_optimizer.step()

        # Delayed policy updates: the actor and its target are refreshed only every
        # `policy_delay` steps.
        if (self.step_count + 1) % self.policy_delay == 0:
            critic_out = self.critic_1(states,
                                       self.actor(states),
                                       training=True)
            actor_loss, actor_aux_losses = -critic_out[0], critic_out[1]
            actor_loss = actor_loss.mean() + compute_sum_aux_losses(
                actor_aux_losses)
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()
            self.soft_update(self.actor_target, self.actor, self.actor_tau)

            self.num_actor_updates += 1
            output = {
                "loss/critic_1": {
                    "type": "scalar",
                    "data": loss_Q1.data.cpu().numpy(),
                    "freq": 10,
                },
                "loss/actor": {
                    "type": "scalar",
                    "data": actor_loss.data.cpu().numpy(),
                    "freq": 10,
                },
            }
        self.soft_update(self.critic_1_target, self.critic_1, self.critic_tau)
        self.soft_update(self.critic_2_target, self.critic_2, self.critic_tau)
        self.current_iteration += 1
        return output

    def soft_update(self, target, src, tau):
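        # Polyak averaging: target <- (1 - tau) * target + tau * source.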
        for target_param, param in zip(target.parameters(), src.parameters()):
            target_param.detach_()
            target_param.copy_(target_param * (1.0 - tau) + param * tau)

    def load(self, model_dir):
        model_dir = pathlib.Path(model_dir)
        map_location = None
        if self.device and self.device.type == "cpu":
            map_location = "cpu"
        self.actor.load_state_dict(
            torch.load(model_dir / "actor.pth", map_location=map_location))
        self.actor_target.load_state_dict(
            torch.load(model_dir / "actor_target.pth",
                       map_location=map_location))
        self.critic_1.load_state_dict(
            torch.load(model_dir / "critic_1.pth", map_location=map_location))
        self.critic_1_target.load_state_dict(
            torch.load(model_dir / "critic_1_target.pth",
                       map_location=map_location))
        self.critic_2.load_state_dict(
            torch.load(model_dir / "critic_2.pth", map_location=map_location))
        self.critic_2_target.load_state_dict(
            torch.load(model_dir / "critic_2_target.pth",
                       map_location=map_location))

    def save(self, model_dir):
        model_dir = pathlib.Path(model_dir)
        torch.save(self.actor.state_dict(), model_dir / "actor.pth")
        torch.save(
            self.actor_target.state_dict(),
            model_dir / "actor_target.pth",
        )
        torch.save(self.critic_1.state_dict(), model_dir / "critic_1.pth")
        torch.save(
            self.critic_1_target.state_dict(),
            model_dir / "critic_1_target.pth",
        )
        torch.save(self.critic_2.state_dict(), model_dir / "critic_2.pth")
        torch.save(
            self.critic_2_target.state_dict(),
            model_dir / "critic_2_target.pth",
        )
Example #2
class DQNPolicy(Agent):
    lane_actions = [
        "keep_lane", "slow_down", "change_lane_left", "change_lane_right"
    ]

    def __init__(
        self,
        policy_params=None,
        checkpoint_dir=None,
    ):
        self.policy_params = policy_params
        network_class = DQNWithSocialEncoder
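        # Epsilon-greedy exploration; epsilon is annealed from 1.0 towards 0.05
        # over 100000 steps.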
        self.epsilon_obj = EpsilonExplore(1.0, 0.05, 100000)
        action_space_type = policy_params["action_space_type"]
        if action_space_type == "continuous":
            discrete_action_spaces = [
                np.asarray([-0.25, 0.0, 0.5, 0.75, 1.0]),
                np.asarray([
                    -1.0, -0.75, -0.5, -0.25, -0.1, 0.0, 0.1, 0.25, 0.5, 0.75,
                    1.0
                ]),
            ]
        else:
            discrete_action_spaces = [[0], [1]]
        action_size = discrete_action_spaces
        self.merge_action_spaces = 0 if action_space_type == "continuous" else -1
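        # merge_action_spaces: 1 merges all sub-spaces into one joint discrete space,
        # 0 keeps a separate Q head per sub-space (throttle, steering), and -1 uses
        # the listed discrete actions directly as a single head.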

        self.step_count = 0
        self.update_count = 0
        self.num_updates = 0
        self.current_sticky = 0
        self.current_iteration = 0

        lr = float(policy_params["lr"])
        seed = int(policy_params["seed"])
        self.train_step = int(policy_params["train_step"])
        self.target_update = float(policy_params["target_update"])
        self.device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(self.device_name)
        self.warmup = int(policy_params["warmup"])
        self.gamma = float(policy_params["gamma"])
        self.batch_size = int(policy_params["batch_size"])
        self.use_ddqn = policy_params["use_ddqn"]
        self.sticky_actions = int(policy_params["sticky_actions"])
        prev_action_size = int(policy_params["prev_action_size"])
        self.prev_action = np.zeros(prev_action_size)

        if self.merge_action_spaces == 1:
            index2action, action2index = merge_discrete_action_spaces(
                *action_size)
            self.index2actions = [index2action]
            self.action2indexs = [action2index]
            self.num_actions = [len(index2action)]
        elif self.merge_action_spaces == 0:
            self.index2actions = [
                merge_discrete_action_spaces([each])[0] for each in action_size
            ]
            self.action2indexs = [
                merge_discrete_action_spaces([each])[1] for each in action_size
            ]
            self.num_actions = [len(e) for e in action_size]
        else:
            index_to_actions = [
                e.tolist() if not isinstance(e, list) else e
                for e in action_size
            ]
            action_to_indexs = {
                str(k): v
                for k, v in zip(
                    index_to_actions,
                    np.arange(len(index_to_actions)).astype(int))
            }
            self.index2actions, self.action2indexs = (
                [index_to_actions],
                [action_to_indexs],
            )
            self.num_actions = [len(index_to_actions)]

        # state preprocessing
        self.social_policy_hidden_units = int(
            policy_params["social_vehicles"].get("social_policy_hidden_units",
                                                 0))
        self.social_capacity = int(policy_params["social_vehicles"].get(
            "social_capacity", 0))
        self.observation_num_lookahead = int(
            policy_params.get("observation_num_lookahead", 0))
        self.social_policy_init_std = int(policy_params["social_vehicles"].get(
            "social_policy_init_std", 0))
        self.num_social_features = int(policy_params["social_vehicles"].get(
            "num_social_features", 0))
        self.social_vehicle_config = get_social_vehicle_configs(
            **policy_params["social_vehicles"])

        self.social_vehicle_encoder = self.social_vehicle_config["encoder"]
        self.state_description = get_state_description(
            policy_params["social_vehicles"],
            policy_params["observation_num_lookahead"],
            prev_action_size,
        )
        self.social_feature_encoder_class = self.social_vehicle_encoder[
            "social_feature_encoder_class"]
        self.social_feature_encoder_params = self.social_vehicle_encoder[
            "social_feature_encoder_params"]

        self.checkpoint_dir = checkpoint_dir
        self.reset()

        torch.manual_seed(seed)
        network_params = {
            "state_size": self.state_size,
            "social_feature_encoder_class": self.social_feature_encoder_class,
            "social_feature_encoder_params":
            self.social_feature_encoder_params,
        }
        self.online_q_network = network_class(
            num_actions=self.num_actions,
            **(network_params if network_params else {}),
        ).to(self.device)
        self.target_q_network = network_class(
            num_actions=self.num_actions,
            **(network_params if network_params else {}),
        ).to(self.device)
        self.update_target_network()

        self.optimizers = torch.optim.Adam(
            params=self.online_q_network.parameters(), lr=lr)
        self.loss_func = nn.MSELoss(reduction="none")

        if self.checkpoint_dir:
            self.load(self.checkpoint_dir)

        self.action_space_type = "continuous"
        self.to_real_action = to_3d_action
        self.state_preprocessor = StatePreprocessor(preprocess_state,
                                                    to_2d_action,
                                                    self.state_description)
        self.replay = ReplayBuffer(
            buffer_size=int(policy_params["replay_buffer"]["buffer_size"]),
            batch_size=int(policy_params["replay_buffer"]["batch_size"]),
            state_preprocessor=self.state_preprocessor,
            device_name=self.device_name,
        )

    def lane_action_to_index(self, state):
        state = state.copy()
        if (len(state["action"]) == 3 and (state["action"] == np.asarray(
            [0, 0, 0])).all()):  # initial action
            state["action"] = np.asarray([0])
        else:
            state["action"] = self.lane_actions.index(state["action"])
        return state

    @property
    def state_size(self):
        # Adjusting state_size based on number of features (ego+social)
        size = sum(self.state_description["low_dim_states"].values())
        if self.social_feature_encoder_class:
            size += self.social_feature_encoder_class(
                **self.social_feature_encoder_params).output_dim
        else:
            size += self.social_capacity * self.num_social_features
        return size

    def reset(self):
        self.eps_throttles = []
        self.eps_steers = []
        self.eps_step = 0
        self.current_sticky = 0

    def soft_update(self, target, src, tau):
        for target_param, param in zip(target.parameters(), src.parameters()):
            target_param.detach_()
            target_param.copy_(target_param * (1.0 - tau) + param * tau)

    def update_target_network(self):
        self.target_q_network.load_state_dict(
            self.online_q_network.state_dict().copy())

    def act(self, *args, **kwargs):
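        # Sticky actions: the selected action is repeated for `sticky_actions`
        # consecutive calls before a new one is chosen.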
        if self.current_sticky == 0:
            self.action = self._act(*args, **kwargs)
        self.current_sticky = (self.current_sticky + 1) % self.sticky_actions
        self.current_iteration += 1
        return self.to_real_action(self.action)

    def _act(self, state, explore=True):
        epsilon = self.epsilon_obj.get_epsilon()
        if not explore or np.random.rand() > epsilon:
            state = self.state_preprocessor(
                state,
                normalize=True,
                unsqueeze=True,
                device=self.device,
                social_capacity=self.social_capacity,
                observation_num_lookahead=self.observation_num_lookahead,
                social_vehicle_config=self.social_vehicle_config,
                prev_action=self.prev_action,
            )
            self.online_q_network.eval()
            with torch.no_grad():
                qs = self.online_q_network(state)
            qs = [q.data.cpu().numpy().flatten() for q in qs]
            # out_str = " || ".join(
            #     [
            #         " ".join(
            #             [
            #                 "{}: {:.4f}".format(index2action[j], q[j])
            #                 for j in range(num_action)
            #             ]
            #         )
            #         for index2action, q, num_action in zip(
            #             self.index2actions, qs, self.num_actions
            #         )
            #     ]
            # )
            # print(out_str)
            inds = [np.argmax(q) for q in qs]
        else:
            inds = [
                np.random.randint(num_action)
                for num_action in self.num_actions
            ]
        action = []
        for j, ind in enumerate(inds):
            action.extend(self.index2actions[j][ind])
        self.epsilon_obj.step()
        self.eps_step += 1
        action = np.asarray(action)
        return action

    def save(self, model_dir):
        model_dir = pathlib.Path(model_dir)
        torch.save(self.online_q_network.state_dict(),
                   model_dir / "online.pth")
        torch.save(self.target_q_network.state_dict(),
                   model_dir / "target.pth")

    def load(self, model_dir, cpu=False):
        model_dir = pathlib.Path(model_dir)
        print("loading from :", model_dir)

        map_location = None
        if cpu:
            map_location = torch.device("cpu")
        self.online_q_network.load_state_dict(
            torch.load(model_dir / "online.pth", map_location=map_location))
        self.target_q_network.load_state_dict(
            torch.load(model_dir / "target.pth", map_location=map_location))
        print("Model loaded")

    def step(self, state, action, reward, next_state, done, others=None):
        # Don't treat a timeout (reaching the max episode steps) as a terminal state.
        max_steps_reached = state["events"].reached_max_episode_steps
        if max_steps_reached:
            done = False
        if self.action_space_type == "continuous":
            action = to_2d_action(action)
            _action = ([[e] for e in action]
                       if not self.merge_action_spaces else [action.tolist()])
            action_index = np.asarray([
                action2index[str(e)]
                for action2index, e in zip(self.action2indexs, _action)
            ])
        else:
            action_index = self.lane_actions.index(action)
            action = action_index
        self.replay.add(
            state=state,
            action=action_index,
            reward=reward,
            next_state=next_state,
            done=done,
            others=others,
            social_capacity=self.social_capacity,
            observation_num_lookahead=self.observation_num_lookahead,
            social_vehicle_config=self.social_vehicle_config,
            prev_action=self.prev_action,
        )
        if (self.step_count % self.train_step == 0
                and len(self.replay) >= self.batch_size
                and (self.warmup is None or len(self.replay) >= self.warmup)):
            out = self.learn()
            self.update_count += 1
        else:
            out = {}

        if self.target_update > 1 and self.step_count % self.target_update == 0:
            self.update_target_network()
        elif self.target_update < 1.0:
            self.soft_update(self.target_q_network, self.online_q_network,
                             self.target_update)
        self.step_count += 1
        self.prev_action = action

        return out

    def learn(self):
        states, actions, rewards, next_states, dones, others = self.replay.sample(
            device=self.device)
        if not self.merge_action_spaces:
            actions = torch.chunk(actions, len(self.num_actions), -1)
        else:
            actions = [actions]

        self.target_q_network.eval()
        with torch.no_grad():
            qs_next_target = self.target_q_network(next_states)

        if self.use_ddqn:
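            # Double DQN: select the greedy next action with the online network but
            # evaluate it with the target network.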
            self.online_q_network.eval()
            with torch.no_grad():
                qs_next_online = self.online_q_network(next_states)
            next_actions = [
                torch.argmax(q_next_online, dim=1, keepdim=True)
                for q_next_online in qs_next_online
            ]
        else:
            next_actions = [
                torch.argmax(q_next_target, dim=1, keepdim=True)
                for q_next_target in qs_next_target
            ]

        qs_next_target = [
            torch.gather(q_next_target, 1, next_action)
            for q_next_target, next_action in zip(qs_next_target, next_actions)
        ]

        self.online_q_network.train()
        qs, aux_losses = self.online_q_network(states, training=True)
        qs = [
            torch.gather(q, 1, action.long())
            for q, action in zip(qs, actions)
        ]
        qs_target_value = [
            rewards + self.gamma * (1 - dones) * q_next_target
            for q_next_target in qs_next_target
        ]
        td_loss = [
            self.loss_func(q, q_target_value).mean()
            for q, q_target_value in zip(qs, qs_target_value)
        ]
        mean_td_loss = sum(td_loss) / len(td_loss)

        loss = mean_td_loss + sum(
            [e["value"] * e["weight"] for e in aux_losses.values()])

        self.optimizers.zero_grad()
        loss.backward()
        self.optimizers.step()

        out = {}
        out.update({
            "loss/td{}".format(j): {
                "type": "scalar",
                "data": td_loss[j].data.cpu().numpy(),
                "freq": 10,
            }
            for j in range(len(td_loss))
        })
        out.update({
            "loss/{}".format(k): {
                "type": "scalar",
                "data": v["value"],  # .detach().cpu().numpy(),
                "freq": 10,
            }
            for k, v in aux_losses.items()
        })
        out.update({"loss/all": {"type": "scalar", "data": loss, "freq": 10}})

        self.num_updates += 1
        return out
Example #3
class DQNPolicy(Agent):
    lane_actions = [
        "keep_lane", "slow_down", "change_lane_left", "change_lane_right"
    ]

    def __init__(
        self,
        policy_params=None,
        checkpoint_dir=None,
    ):
        self.policy_params = policy_params
        self.lr = float(policy_params["lr"])
        self.seed = int(policy_params["seed"])
        self.train_step = int(policy_params["train_step"])
        self.target_update = float(policy_params["target_update"])
        self.warmup = int(policy_params["warmup"])
        self.gamma = float(policy_params["gamma"])
        self.batch_size = int(policy_params["batch_size"])
        self.use_ddqn = policy_params["use_ddqn"]
        self.sticky_actions = int(policy_params["sticky_actions"])
        self.epsilon_obj = EpsilonExplore(1.0, 0.05, 100000)
        self.step_count = 0
        self.update_count = 0
        self.num_updates = 0
        self.current_sticky = 0
        self.current_iteration = 0
        self.action_type = adapters.type_from_string(
            policy_params["action_type"])
        self.observation_type = adapters.type_from_string(
            policy_params["observation_type"])
        self.reward_type = adapters.type_from_string(
            policy_params["reward_type"])

        if self.action_type == adapters.AdapterType.DefaultActionContinuous:
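            # DQN needs a discrete action set, so throttle and steering are
            # discretized into fixed bins.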
            discrete_action_spaces = [
                np.asarray([-0.25, 0.0, 0.5, 0.75, 1.0]),
                np.asarray([
                    -1.0, -0.75, -0.5, -0.25, -0.1, 0.0, 0.1, 0.25, 0.5, 0.75,
                    1.0
                ]),
            ]
            self.index2actions = [
                merge_discrete_action_spaces([discrete_action_space])[0]
                for discrete_action_space in discrete_action_spaces
            ]
            self.action2indexs = [
                merge_discrete_action_spaces([discrete_action_space])[1]
                for discrete_action_space in discrete_action_spaces
            ]
            self.merge_action_spaces = 0
            self.num_actions = [
                len(discrete_action_space)
                for discrete_action_space in discrete_action_spaces
            ]
            self.action_size = 2
            self.to_real_action = to_3d_action
        elif self.action_type == adapters.AdapterType.DefaultActionDiscrete:
            discrete_action_spaces = [[0], [1], [2], [3]]
            index_to_actions = [
                discrete_action_space.tolist()
                if not isinstance(discrete_action_space, list) else
                discrete_action_space
                for discrete_action_space in discrete_action_spaces
            ]
            action_to_indexs = {
                str(discrete_action): index
                for discrete_action, index in zip(
                    index_to_actions,
                    np.arange(len(index_to_actions)).astype(int))
            }
            self.index2actions = [index_to_actions]
            self.action2indexs = [action_to_indexs]
            self.merge_action_spaces = -1
            self.num_actions = [len(index_to_actions)]
            self.action_size = 1
            self.to_real_action = lambda action: self.lane_actions[action[0]]
        else:
            raise Exception(
                f"DQN baseline does not support the '{self.action_type}' action type."
            )

        if self.observation_type == adapters.AdapterType.DefaultObservationVector:
            observation_space = adapters.space_from_type(self.observation_type)
            low_dim_states_size = observation_space["low_dim_states"].shape[0]
            social_capacity = observation_space["social_vehicles"].shape[0]
            num_social_features = observation_space["social_vehicles"].shape[1]

            # Get information to build the encoder.
            encoder_key = policy_params["social_vehicles"]["encoder_key"]
            social_policy_hidden_units = int(
                policy_params["social_vehicles"].get(
                    "social_policy_hidden_units", 0))
            social_policy_init_std = int(policy_params["social_vehicles"].get(
                "social_policy_init_std", 0))
            social_vehicle_config = get_social_vehicle_configs(
                encoder_key=encoder_key,
                num_social_features=num_social_features,
                social_capacity=social_capacity,
                seed=self.seed,
                social_policy_hidden_units=social_policy_hidden_units,
                social_policy_init_std=social_policy_init_std,
            )
            social_vehicle_encoder = social_vehicle_config["encoder"]
            social_feature_encoder_class = social_vehicle_encoder[
                "social_feature_encoder_class"]
            social_feature_encoder_params = social_vehicle_encoder[
                "social_feature_encoder_params"]

            # Calculate the state size based on the number of features (ego + social).
            state_size = low_dim_states_size
            if social_feature_encoder_class:
                state_size += social_feature_encoder_class(
                    **social_feature_encoder_params).output_dim
            else:
                state_size += social_capacity * num_social_features
            # Add the action size to account for the previous action.
            state_size += self.action_size

            network_class = DQNWithSocialEncoder
            network_params = {
                "num_actions": self.num_actions,
                "state_size": state_size,
                "social_feature_encoder_class": social_feature_encoder_class,
                "social_feature_encoder_params": social_feature_encoder_params,
            }
        elif self.observation_type == adapters.AdapterType.DefaultObservationImage:
            observation_space = adapters.space_from_type(self.observation_type)
            stack_size = observation_space.shape[0]
            image_shape = (observation_space.shape[1],
                           observation_space.shape[2])

            network_class = DQNCNN
            network_params = {
                "n_in_channels": stack_size,
                "image_dim": image_shape,
                "num_actions": self.num_actions,
            }
        else:
            raise Exception(
                f"DQN baseline does not support the '{self.observation_type}' "
                f"observation type.")

        self.prev_action = np.zeros(self.action_size)
        self.checkpoint_dir = checkpoint_dir
        torch.manual_seed(self.seed)
        self.device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(self.device_name)
        self.online_q_network = network_class(**network_params).to(self.device)
        self.target_q_network = network_class(**network_params).to(self.device)
        self.update_target_network()
        self.optimizers = torch.optim.Adam(
            params=self.online_q_network.parameters(), lr=self.lr)
        self.loss_func = nn.MSELoss(reduction="none")
        self.replay = ReplayBuffer(
            buffer_size=int(policy_params["replay_buffer"]["buffer_size"]),
            batch_size=int(policy_params["replay_buffer"]["batch_size"]),
            observation_type=self.observation_type,
            device_name=self.device_name,
        )
        self.reset()
        if self.checkpoint_dir:
            self.load(self.checkpoint_dir)

    def lane_action_to_index(self, state):
        state = state.copy()
        if (len(state["action"]) == 3 and (state["action"] == np.asarray(
            [0, 0, 0])).all()):  # initial action
            state["action"] = np.asarray([0])
        else:
            state["action"] = self.lane_actions.index(state["action"])
        return state

    def reset(self):
        self.eps_throttles = []
        self.eps_steers = []
        self.eps_step = 0
        self.current_sticky = 0

    def soft_update(self, target, src, tau):
        for target_param, param in zip(target.parameters(), src.parameters()):
            target_param.detach_()
            target_param.copy_(target_param * (1.0 - tau) + param * tau)

    def update_target_network(self):
        self.target_q_network.load_state_dict(
            self.online_q_network.state_dict().copy())

    def act(self, *args, **kwargs):
        if self.current_sticky == 0:
            self.action = self._act(*args, **kwargs)
        self.current_sticky = (self.current_sticky + 1) % self.sticky_actions
        self.current_iteration += 1
        return self.to_real_action(self.action)

    def _act(self, state, explore=True):
        epsilon = self.epsilon_obj.get_epsilon()
        if not explore or np.random.rand() > epsilon:
            state = copy.deepcopy(state)
            if self.observation_type == adapters.AdapterType.DefaultObservationVector:
                # Default vector observation type.
                state["low_dim_states"] = np.float32(
                    np.append(state["low_dim_states"], self.prev_action))
                state["social_vehicles"] = (torch.from_numpy(
                    state["social_vehicles"]).unsqueeze(0).to(self.device))
                state["low_dim_states"] = (torch.from_numpy(
                    state["low_dim_states"]).unsqueeze(0).to(self.device))
            else:
                # Default image observation type.
                state = torch.from_numpy(state).unsqueeze(0).to(self.device)
            self.online_q_network.eval()
            with torch.no_grad():
                qs = self.online_q_network(state)
            qs = [q.data.cpu().numpy().flatten() for q in qs]
            # out_str = " || ".join(
            #     [
            #         " ".join(
            #             [
            #                 "{}: {:.4f}".format(index2action[j], q[j])
            #                 for j in range(num_action)
            #             ]
            #         )
            #         for index2action, q, num_action in zip(
            #             self.index2actions, qs, self.num_actions
            #         )
            #     ]
            # )
            # print(out_str)
            inds = [np.argmax(q) for q in qs]
        else:
            inds = [
                np.random.randint(num_action)
                for num_action in self.num_actions
            ]
        action = []
        for j, ind in enumerate(inds):
            action.extend(self.index2actions[j][ind])
        self.epsilon_obj.step()
        self.eps_step += 1
        action = np.asarray(action)
        return action

    def save(self, model_dir):
        model_dir = pathlib.Path(model_dir)
        torch.save(self.online_q_network.state_dict(),
                   model_dir / "online.pth")
        torch.save(self.target_q_network.state_dict(),
                   model_dir / "target.pth")

    def load(self, model_dir, cpu=False):
        model_dir = pathlib.Path(model_dir)
        print("loading from :", model_dir)

        map_location = None
        if cpu:
            map_location = torch.device("cpu")
        self.online_q_network.load_state_dict(
            torch.load(model_dir / "online.pth", map_location=map_location))
        self.target_q_network.load_state_dict(
            torch.load(model_dir / "target.pth", map_location=map_location))
        print("Model loaded")

    def step(self, state, action, reward, next_state, done, info, others=None):
        # Don't treat a timeout (reaching the max episode steps) as a terminal state.
        max_steps_reached = info["logs"]["events"].reached_max_episode_steps
        if max_steps_reached:
            done = False
        if self.action_type == adapters.AdapterType.DefaultActionContinuous:
            action = to_2d_action(action)
            _action = ([[e] for e in action]
                       if not self.merge_action_spaces else [action.tolist()])
            action_index = np.asarray([
                action2index[str(e)]
                for action2index, e in zip(self.action2indexs, _action)
            ])
        else:
            action_index = self.lane_actions.index(action)
            action = action_index
        self.replay.add(
            state=state,
            action=action_index,
            reward=reward,
            next_state=next_state,
            done=done,
            others=others,
            prev_action=self.prev_action,
        )
        if (self.step_count % self.train_step == 0
                and len(self.replay) >= self.batch_size
                and (self.warmup is None or len(self.replay) >= self.warmup)):
            out = self.learn()
            self.update_count += 1
        else:
            out = {}

        if self.target_update > 1 and self.step_count % self.target_update == 0:
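            # A target_update value above 1 is treated as a hard-update period in steps.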
            self.update_target_network()
        elif self.target_update < 1.0:
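            # A value below 1.0 is used as the soft-update coefficient (tau).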
            self.soft_update(self.target_q_network, self.online_q_network,
                             self.target_update)
        self.step_count += 1
        self.prev_action = action

        return out

    def learn(self):
        states, actions, rewards, next_states, dones, others = self.replay.sample(
            device=self.device)
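        # With one Q head per action dimension (merge_action_spaces == 0), split the
        # stored action indices into one tensor per head.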
        if not self.merge_action_spaces:
            actions = torch.chunk(actions, len(self.num_actions), -1)
        else:
            actions = [actions]

        self.target_q_network.eval()
        with torch.no_grad():
            qs_next_target = self.target_q_network(next_states)

        if self.use_ddqn:
            self.online_q_network.eval()
            with torch.no_grad():
                qs_next_online = self.online_q_network(next_states)
            next_actions = [
                torch.argmax(q_next_online, dim=1, keepdim=True)
                for q_next_online in qs_next_online
            ]
        else:
            next_actions = [
                torch.argmax(q_next_target, dim=1, keepdim=True)
                for q_next_target in qs_next_target
            ]

        qs_next_target = [
            torch.gather(q_next_target, 1, next_action)
            for q_next_target, next_action in zip(qs_next_target, next_actions)
        ]

        self.online_q_network.train()
        qs, aux_losses = self.online_q_network(states, training=True)
        qs = [
            torch.gather(q, 1, action.long())
            for q, action in zip(qs, actions)
        ]
        qs_target_value = [
            rewards + self.gamma * (1 - dones) * q_next_target
            for q_next_target in qs_next_target
        ]
        td_loss = [
            self.loss_func(q, q_target_value).mean()
            for q, q_target_value in zip(qs, qs_target_value)
        ]
        mean_td_loss = sum(td_loss) / len(td_loss)

        loss = mean_td_loss + sum(
            [e["value"] * e["weight"] for e in aux_losses.values()])

        self.optimizers.zero_grad()
        loss.backward()
        self.optimizers.step()

        out = {}
        out.update({
            "loss/td{}".format(j): {
                "type": "scalar",
                "data": td_loss[j].data.cpu().numpy(),
                "freq": 10,
            }
            for j in range(len(td_loss))
        })
        out.update({
            "loss/{}".format(k): {
                "type": "scalar",
                "data": v["value"],  # .detach().cpu().numpy(),
                "freq": 10,
            }
            for k, v in aux_losses.items()
        })
        out.update({"loss/all": {"type": "scalar", "data": loss, "freq": 10}})

        self.num_updates += 1
        return out
Example #4
    def test_image_replay_buffer(self):
        TRANSITIONS = 1024  # The number of transitions to save in the replay buffer.
        STACK_SIZE = 4  # The stack size of the images.
        ACTION_SIZE = 3  # The size of each action.
        IMAGE_WIDTH = 64  # The width of each image.
        IMAGE_HEIGHT = 64  # The height of each image.
        BUFFER_SIZE = 1024  # The size of the replay buffer.
        BATCH_SIZE = 128  # Batch size of each sample from the replay buffer.
        NUM_SAMPLES = 10  # Number of times to sample from the replay buffer.

        replay_buffer = ReplayBuffer(
            buffer_size=BUFFER_SIZE,
            batch_size=BATCH_SIZE,
            observation_type=adapters.AdapterType.DefaultObservationImage,
            device_name="cpu",
        )

        (
            states,
            next_states,
            previous_actions,
            actions,
            rewards,
            dones,
        ) = generate_image_transitions(TRANSITIONS, STACK_SIZE, IMAGE_HEIGHT,
                                       IMAGE_WIDTH, ACTION_SIZE)

        for state, next_state, action, previous_action, reward, done in zip(
                states, next_states, actions, previous_actions, rewards,
                dones):
            replay_buffer.add(
                state=state,
                next_state=next_state,
                action=action,
                prev_action=previous_action,
                reward=reward,
                done=done,
            )

        for _ in range(NUM_SAMPLES):
            sample = replay_buffer.sample()

            for state, action, reward, next_state, done, _ in zip(*sample):
                state = state.numpy()
                action = action.numpy()
                reward = reward.numpy()[0]
                next_state = next_state.numpy()
                done = bool(done.numpy()[0])

                index_of_state = None
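                # Linear search for the sampled state among the original transitions so
                # the remaining sampled fields can be checked against the same index.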
                for index, original_state in enumerate(states):
                    if np.array_equal(original_state, state):
                        index_of_state = index
                        break

                self.assertIn(state, states)
                self.assertIn(next_state, next_states)
                self.assertIn(action, actions)
                self.assertIn(reward, rewards)
                self.assertIn(done, dones)
                self.assertTrue(np.array_equal(state, states[index_of_state]))
                self.assertTrue(
                    np.array_equal(next_state, next_states[index_of_state]))
                self.assertTrue(np.array_equal(action,
                                               actions[index_of_state]))
                self.assertEqual(reward, rewards[index_of_state])
                self.assertEqual(done, dones[index_of_state])
Example #5
class SACPolicy(Agent):
    def __init__(
        self,
        policy_params=None,
        checkpoint_dir=None,
    ):
        # print("LOADING THE PARAMS", policy_params, checkpoint_dir)
        self.policy_params = policy_params
        self.gamma = float(policy_params["gamma"])
        self.critic_lr = float(policy_params["critic_lr"])
        self.actor_lr = float(policy_params["actor_lr"])
        self.critic_update_rate = int(policy_params["critic_update_rate"])
        self.policy_update_rate = int(policy_params["policy_update_rate"])
        self.warmup = int(policy_params["warmup"])
        self.seed = int(policy_params["seed"])
        self.batch_size = int(policy_params["batch_size"])
        self.hidden_units = int(policy_params["hidden_units"])
        self.tau = float(policy_params["tau"])
        self.initial_alpha = float(policy_params["initial_alpha"])
        self.logging_freq = int(policy_params["logging_freq"])
        self.action_size = 2
        self.prev_action = np.zeros(self.action_size)
        self.action_type = adapters.type_from_string(
            policy_params["action_type"])
        self.observation_type = adapters.type_from_string(
            policy_params["observation_type"])
        self.reward_type = adapters.type_from_string(
            policy_params["reward_type"])

        if self.action_type != adapters.AdapterType.DefaultActionContinuous:
            raise Exception(
                f"SAC baseline only supports the "
                f"{adapters.AdapterType.DefaultActionContinuous} action type.")
        if self.observation_type != adapters.AdapterType.DefaultObservationVector:
            raise Exception(
                f"SAC baseline only supports the "
                f"{adapters.AdapterType.DefaultObservationVector} observation type."
            )

        self.observation_space = adapters.space_from_type(
            self.observation_type)
        self.low_dim_states_size = self.observation_space[
            "low_dim_states"].shape[0]
        self.social_capacity = self.observation_space["social_vehicles"].shape[
            0]
        self.num_social_features = self.observation_space[
            "social_vehicles"].shape[1]

        self.encoder_key = policy_params["social_vehicles"]["encoder_key"]
        self.social_policy_hidden_units = int(
            policy_params["social_vehicles"].get("social_policy_hidden_units",
                                                 0))
        self.social_policy_init_std = int(policy_params["social_vehicles"].get(
            "social_policy_init_std", 0))
        self.social_vehicle_config = get_social_vehicle_configs(
            encoder_key=self.encoder_key,
            num_social_features=self.num_social_features,
            social_capacity=self.social_capacity,
            seed=self.seed,
            social_policy_hidden_units=self.social_policy_hidden_units,
            social_policy_init_std=self.social_policy_init_std,
        )
        self.social_vehicle_encoder = self.social_vehicle_config["encoder"]
        self.social_feature_encoder_class = self.social_vehicle_encoder[
            "social_feature_encoder_class"]
        self.social_feature_encoder_params = self.social_vehicle_encoder[
            "social_feature_encoder_params"]

        # others
        self.checkpoint_dir = checkpoint_dir
        self.device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(self.device_name)
        self.save_codes = policy_params.get("save_codes")
        self.memory = ReplayBuffer(
            buffer_size=int(policy_params["replay_buffer"]["buffer_size"]),
            batch_size=int(policy_params["replay_buffer"]["batch_size"]),
            observation_type=self.observation_type,
            device_name=self.device_name,
        )
        self.current_iteration = 0
        self.steps = 0
        self.init_networks()
        if checkpoint_dir:
            self.load(checkpoint_dir)

    @property
    def state_size(self):
        # Adjusting state_size based on number of features (ego+social)
        size = self.low_dim_states_size
        if self.social_feature_encoder_class:
            size += self.social_feature_encoder_class(
                **self.social_feature_encoder_params).output_dim
        else:
            size += self.social_capacity * self.num_social_features
        # adding the previous action
        size += self.action_size
        return size

    def init_networks(self):
        self.sac_net = SACNetwork(
            action_size=self.action_size,
            state_size=self.state_size,
            hidden_units=self.hidden_units,
            seed=self.seed,
            initial_alpha=self.initial_alpha,
            social_feature_encoder_class=self.social_feature_encoder_class,
            social_feature_encoder_params=self.social_feature_encoder_params,
        ).to(self.device_name)

        self.actor_optimizer = torch.optim.Adam(
            self.sac_net.actor.parameters(), lr=self.actor_lr)

        self.critic_optimizer = torch.optim.Adam(
            self.sac_net.critic.parameters(), lr=self.critic_lr)

        self.log_alpha_optimizer = torch.optim.Adam([self.sac_net.log_alpha],
                                                    lr=self.critic_lr)

    def act(self, state, explore=True):
        state = copy.deepcopy(state)
        state["low_dim_states"] = np.float32(
            np.append(state["low_dim_states"], self.prev_action))
        state["social_vehicles"] = (torch.from_numpy(
            state["social_vehicles"]).unsqueeze(0).to(self.device))
        state["low_dim_states"] = (torch.from_numpy(
            state["low_dim_states"]).unsqueeze(0).to(self.device))

        action, _, mean = self.sac_net.sample(state)

        if explore:  # training mode
            action = torch.squeeze(action, 0)
            action = action.detach().cpu().numpy()
        else:  # testing mode
            mean = torch.squeeze(mean, 0)
            action = mean.detach().cpu().numpy()
        return to_3d_action(action)

    def step(self, state, action, reward, next_state, done, info):
        # Don't treat a timeout (reaching the max episode steps) as a terminal state.
        max_steps_reached = info["logs"]["events"].reached_max_episode_steps
        if max_steps_reached:
            done = False
        action = to_2d_action(action)
        self.memory.add(
            state=state,
            action=action,
            reward=reward,
            next_state=next_state,
            done=float(done),
            prev_action=self.prev_action,
        )
        self.steps += 1
        output = {}
        if self.steps > max(self.warmup, self.batch_size):
            states, actions, rewards, next_states, dones, others = self.memory.sample(
                device=self.device_name)
            if self.steps % self.critic_update_rate == 0:
                critic_loss = self.update_critic(states, actions, rewards,
                                                 next_states, dones)
                output["loss/critic_loss"] = {
                    "type": "scalar",
                    "data": critic_loss.item(),
                    "freq": 2,
                }

            if self.steps % self.policy_update_rate == 0:
                actor_loss, temp_loss = self.update_actor_temp(
                    states, actions, rewards, next_states, dones)
                output["loss/actor_loss"] = {
                    "type": "scalar",
                    "data": actor_loss.item(),
                    "freq": self.logging_freq,
                }
                output["loss/temp_loss"] = {
                    "type": "scalar",
                    "data": temp_loss.item(),
                    "freq": self.logging_freq,
                }
                output["others/alpha"] = {
                    "type": "scalar",
                    "data": self.sac_net.alpha.item(),
                    "freq": self.logging_freq,
                }
                self.current_iteration += 1
            self.target_soft_update(self.sac_net.critic, self.sac_net.target,
                                    self.tau)
        self.prev_action = action if not done else np.zeros(self.action_size)
        return output

    def update_critic(self, states, actions, rewards, next_states, dones):

        q1_current, q2_current, aux_losses = self.sac_net.critic(states,
                                                                 actions,
                                                                 training=True)
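        # Soft Bellman target: V(s') = min(Q1', Q2') - alpha * log pi(a'|s').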
        with torch.no_grad():
            next_actions, log_probs, _ = self.sac_net.sample(next_states)
            q1_next, q2_next = self.sac_net.target(next_states, next_actions)
            v_next = (torch.min(q1_next, q2_next) -
                      self.sac_net.alpha.detach() * log_probs)
            q_target = (rewards + ((1 - dones) * self.gamma * v_next)).detach()

        critic_loss = F.mse_loss(q1_current, q_target) + F.mse_loss(
            q2_current, q_target)

        aux_losses = compute_sum_aux_losses(aux_losses)
        overall_loss = critic_loss + aux_losses
        self.critic_optimizer.zero_grad()
        overall_loss.backward()
        self.critic_optimizer.step()

        return critic_loss

    def update_actor_temp(self, states, actions, rewards, next_states, dones):

        for p in self.sac_net.target.parameters():
            p.requires_grad = False
        for p in self.sac_net.critic.parameters():
            p.requires_grad = False

        # Update the actor: minimize alpha * log pi(a|s) - min(Q1, Q2),
        # i.e. maximize the entropy-regularized value.
        actions, log_probs, aux_losses = self.sac_net.sample(states,
                                                             training=True)
        q1, q2 = self.sac_net.critic(states, actions)
        q_old = torch.min(q1, q2)
        actor_loss = (self.sac_net.alpha.detach() * log_probs - q_old).mean()
        aux_losses = compute_sum_aux_losses(aux_losses)
        overall_loss = actor_loss + aux_losses
        self.actor_optimizer.zero_grad()
        overall_loss.backward()
        self.actor_optimizer.step()

        # update the entropy temperature (alpha):
        temp_loss = (self.sac_net.log_alpha.exp() *
                     (-log_probs.detach().mean() + self.action_size).detach())
        self.log_alpha_optimizer.zero_grad()
        temp_loss.backward()
        self.log_alpha_optimizer.step()
        self.sac_net.alpha.data = self.sac_net.log_alpha.exp().detach()

        for p in self.sac_net.target.parameters():
            p.requires_grad = True
        for p in self.sac_net.critic.parameters():
            p.requires_grad = True

        return actor_loss, temp_loss

    def target_soft_update(self, critic, target_critic, tau):
        # Polyak averaging: theta_target <- tau * theta + (1 - tau) * theta_target
        with torch.no_grad():
            for critic_param, target_critic_param in zip(
                    critic.parameters(), target_critic.parameters()):
                target_critic_param.data = (
                    tau * critic_param.data +
                    (1 - tau) * target_critic_param.data)

    def load(self, model_dir):
        model_dir = pathlib.Path(model_dir)
        map_location = None
        if self.device and self.device.type == "cpu":
            map_location = "cpu"
        self.sac_net.actor.load_state_dict(
            torch.load(model_dir / "actor.pth", map_location=map_location))
        self.sac_net.target.load_state_dict(
            torch.load(model_dir / "target.pth", map_location=map_location))
        self.sac_net.critic.load_state_dict(
            torch.load(model_dir / "critic.pth", map_location=map_location))
        print("<<<<<<< MODEL LOADED >>>>>>>>>", model_dir)

    def save(self, model_dir):
        model_dir = pathlib.Path(model_dir)
        # with open(model_dir / "params.yaml", "w") as file:
        #     yaml.dump(policy_params, file)

        torch.save(self.sac_net.actor.state_dict(), model_dir / "actor.pth")
        torch.save(self.sac_net.target.state_dict(), model_dir / "target.pth")
        torch.save(self.sac_net.critic.state_dict(), model_dir / "critic.pth")
        print("<<<<<<< MODEL SAVED >>>>>>>>>", model_dir)

    def reset(self):
        pass
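
The temperature (alpha) update in update_actor_temp above can be isolated in a few lines. The snippet below is a standalone sketch with made-up values (the batch of log-probabilities and the learning rate are placeholders), not part of the baseline: with a target entropy of -action_size, the loss alpha * (-E[log pi(a|s)] + action_size) pushes alpha up when the policy's entropy falls below the target and down otherwise.

import torch

# Standalone illustration of the temperature update; all numbers are placeholders.
action_size = 2
log_alpha = torch.zeros(1, requires_grad=True)   # alpha = exp(log_alpha) = 1.0
optimizer = torch.optim.Adam([log_alpha], lr=1e-3)

log_probs = torch.full((64,), -1.5)              # pretend batch of log pi(a|s)
temp_loss = log_alpha.exp() * (-log_probs.mean() + action_size).detach()

optimizer.zero_grad()
temp_loss.backward()
optimizer.step()
alpha = log_alpha.exp().detach()                 # updated temperature
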
Example #6
0
class TD3Policy(Agent):
    def __init__(
        self,
        policy_params=None,
        checkpoint_dir=None,
    ):
        self.policy_params = policy_params
        self.action_size = 2
        self.action_range = np.asarray([[-1.0, 1.0], [-1.0, 1.0]], dtype=np.float32)
        self.actor_lr = float(policy_params["actor_lr"])
        self.critic_lr = float(policy_params["critic_lr"])
        self.critic_wd = float(policy_params["critic_wd"])
        self.actor_wd = float(policy_params["actor_wd"])
        self.noise_clip = float(policy_params["noise_clip"])
        self.policy_noise = float(policy_params["policy_noise"])
        self.update_rate = int(policy_params["update_rate"])
        self.policy_delay = int(policy_params["policy_delay"])
        self.warmup = int(policy_params["warmup"])
        self.critic_tau = float(policy_params["critic_tau"])
        self.actor_tau = float(policy_params["actor_tau"])
        self.gamma = float(policy_params["gamma"])
        self.batch_size = int(policy_params["batch_size"])
        self.sigma = float(policy_params["sigma"])
        self.theta = float(policy_params["theta"])
        self.dt = float(policy_params["dt"])
        self.action_low = torch.tensor([[each[0] for each in self.action_range]])
        self.action_high = torch.tensor([[each[1] for each in self.action_range]])
        self.seed = int(policy_params["seed"])
        self.prev_action = np.zeros(self.action_size)
        self.action_type = adapters.type_from_string(policy_params["action_type"])
        self.observation_type = adapters.type_from_string(
            policy_params["observation_type"]
        )
        self.reward_type = adapters.type_from_string(policy_params["reward_type"])

        if self.action_type != adapters.AdapterType.DefaultActionContinuous:
            raise Exception(
                f"TD3 baseline only supports the "
                f"{adapters.AdapterType.DefaultActionContinuous} action type."
            )

        if self.observation_type == adapters.AdapterType.DefaultObservationVector:
            observation_space = adapters.space_from_type(self.observation_type)
            low_dim_states_size = observation_space["low_dim_states"].shape[0]
            social_capacity = observation_space["social_vehicles"].shape[0]
            num_social_features = observation_space["social_vehicles"].shape[1]

            # Get information to build the encoder.
            encoder_key = policy_params["social_vehicles"]["encoder_key"]
            social_policy_hidden_units = int(
                policy_params["social_vehicles"].get("social_policy_hidden_units", 0)
            )
            social_policy_init_std = int(
                policy_params["social_vehicles"].get("social_policy_init_std", 0)
            )
            social_vehicle_config = get_social_vehicle_configs(
                encoder_key=encoder_key,
                num_social_features=num_social_features,
                social_capacity=social_capacity,
                seed=self.seed,
                social_policy_hidden_units=social_policy_hidden_units,
                social_policy_init_std=social_policy_init_std,
            )
            social_vehicle_encoder = social_vehicle_config["encoder"]
            social_feature_encoder_class = social_vehicle_encoder[
                "social_feature_encoder_class"
            ]
            social_feature_encoder_params = social_vehicle_encoder[
                "social_feature_encoder_params"
            ]

            # Calculate the state size based on the number of features (ego + social).
            state_size = low_dim_states_size
            if social_feature_encoder_class:
                state_size += social_feature_encoder_class(
                    **social_feature_encoder_params
                ).output_dim
            else:
                state_size += social_capacity * num_social_features
            # Add the action size to account for the previous action.
            state_size += self.action_size

            actor_network_class = FCActorNetwork
            critic_network_class = FCCrtiicNetwork
            network_params = {
                "state_space": state_size,
                "action_space": self.action_size,
                "seed": self.seed,
                "social_feature_encoder": social_feature_encoder_class(
                    **social_feature_encoder_params
                )
                if social_feature_encoder_class
                else None,
            }
        elif self.observation_type == adapters.AdapterType.DefaultObservationImage:
            observation_space = adapters.space_from_type(self.observation_type)
            stack_size = observation_space.shape[0]
            image_shape = (observation_space.shape[1], observation_space.shape[2])

            actor_network_class = CNNActorNetwork
            critic_network_class = CNNCriticNetwork
            network_params = {
                "input_channels": stack_size,
                "input_dimension": image_shape,
                "action_size": self.action_size,
                "seed": self.seed,
            }
        else:
            raise Exception(
                f"TD3 baseline does not support the '{self.observation_type}' "
                f"observation type."
            )

        # others
        self.checkpoint_dir = checkpoint_dir
        self.device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(self.device_name)
        self.save_codes = (
            policy_params["save_codes"] if "save_codes" in policy_params else None
        )
        self.memory = ReplayBuffer(
            buffer_size=int(policy_params["replay_buffer"]["buffer_size"]),
            batch_size=int(policy_params["replay_buffer"]["batch_size"]),
            observation_type=self.observation_type,
            device_name=self.device_name,
        )
        self.num_actor_updates = 0
        self.current_iteration = 0
        self.step_count = 0
        self.init_networks(actor_network_class, critic_network_class, network_params)
        if checkpoint_dir:
            self.load(checkpoint_dir)

    def init_networks(self, actor_network_class, critic_network_class, network_params):
        # One Ornstein-Uhlenbeck exploration-noise process per action dimension.
        self.noise = [
            OrnsteinUhlenbeckProcess(
                size=(1,), theta=0.01, std=LinearSchedule(0.25), mu=0.0, x0=0.0, dt=1.0
            ),  # throttle
            OrnsteinUhlenbeckProcess(
                size=(1,), theta=0.1, std=LinearSchedule(0.05), mu=0.0, x0=0.0, dt=1.0
            ),  # steering
        ]

        self.actor = actor_network_class(**network_params).to(self.device)
        self.actor_target = actor_network_class(**network_params).to(self.device)
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=self.actor_lr)

        self.critic_1 = critic_network_class(**network_params).to(self.device)
        self.critic_1_target = critic_network_class(**network_params).to(self.device)
        self.critic_1_target.load_state_dict(self.critic_1.state_dict())
        self.critic_1_optimizer = optim.Adam(
            self.critic_1.parameters(), lr=self.critic_lr
        )

        self.critic_2 = critic_network_class(**network_params).to(self.device)
        self.critic_2_target = critic_network_class(**network_params).to(self.device)
        self.critic_2_target.load_state_dict(self.critic_2.state_dict())
        self.critic_2_optimizer = optim.Adam(
            self.critic_2.parameters(), lr=self.critic_lr
        )

    def act(self, state, explore=True):
        state = copy.deepcopy(state)
        if self.observation_type == adapters.AdapterType.DefaultObservationVector:
            # Default vector observation type.
            state["low_dim_states"] = np.float32(
                np.append(state["low_dim_states"], self.prev_action)
            )
            state["social_vehicles"] = (
                torch.from_numpy(state["social_vehicles"]).unsqueeze(0).to(self.device)
            )
            state["low_dim_states"] = (
                torch.from_numpy(state["low_dim_states"]).unsqueeze(0).to(self.device)
            )
        else:
            # Default image observation type.
            state = torch.from_numpy(state).unsqueeze(0).to(self.device)

        self.actor.eval()

        action = self.actor(state).cpu().data.numpy().flatten()

        noise = [self.noise[0].sample(), self.noise[1].sample()]
        if explore:
            action[0] += noise[0]
            action[1] += noise[1]

        self.actor.train()
        action_low, action_high = (
            self.action_low.data.cpu().numpy(),
            self.action_high.data.cpu().numpy(),
        )
        action = np.clip(action, action_low, action_high)[0]

        return to_3d_action(action)

    def step(self, state, action, reward, next_state, done, info):
        # Don't treat reaching the maximum number of episode steps (timeout) as done == True.
        max_steps_reached = info["logs"]["events"].reached_max_episode_steps
        reset_noise = False
        if max_steps_reached:
            done = False
            reset_noise = True

        output = {}
        action = to_2d_action(action)
        self.memory.add(
            state=state,
            action=action,
            reward=reward,
            next_state=next_state,
            done=float(done),
            prev_action=self.prev_action,
        )
        self.step_count += 1
        if reset_noise:
            self.reset()
        if (
            len(self.memory) > max(self.batch_size, self.warmup)
            and (self.step_count + 1) % self.update_rate == 0
        ):
            output = self.learn()

        self.prev_action = action if not done else np.zeros(self.action_size)
        return output

    def reset(self):
        self.noise[0].reset_states()
        self.noise[1].reset_states()

    def learn(self):
        # One TD3 update: clipped double-Q critic targets with target policy
        # smoothing, and a delayed actor (and actor-target) update every
        # `policy_delay` steps.
        output = {}
        states, actions, rewards, next_states, dones, others = self.memory.sample(
            device=self.device
        )
        actions = actions.squeeze(dim=1)
        next_actions = self.actor_target(next_states)
        noise = torch.randn_like(next_actions).mul(self.policy_noise)
        noise = noise.clamp(-self.noise_clip, self.noise_clip)
        next_actions += noise
        next_actions = torch.max(
            torch.min(next_actions, self.action_high.to(self.device)),
            self.action_low.to(self.device),
        )

        target_Q1 = self.critic_1_target(next_states, next_actions)
        target_Q2 = self.critic_2_target(next_states, next_actions)
        target_Q = torch.min(target_Q1, target_Q2)
        target_Q = (rewards + ((1 - dones) * self.gamma * target_Q)).detach()

        # Optimize Critic 1:
        current_Q1, aux_losses_Q1 = self.critic_1(states, actions, training=True)
        loss_Q1 = F.mse_loss(current_Q1, target_Q) + compute_sum_aux_losses(
            aux_losses_Q1
        )
        self.critic_1_optimizer.zero_grad()
        loss_Q1.backward()
        self.critic_1_optimizer.step()

        # Optimize Critic 2:
        current_Q2, aux_losses_Q2 = self.critic_2(states, actions, training=True)
        loss_Q2 = F.mse_loss(current_Q2, target_Q) + compute_sum_aux_losses(
            aux_losses_Q2
        )
        self.critic_2_optimizer.zero_grad()
        loss_Q2.backward()
        self.critic_2_optimizer.step()

        # delayed actor updates
        if (self.step_count + 1) % self.policy_delay == 0:
            critic_out = self.critic_1(states, self.actor(states), training=True)
            actor_loss, actor_aux_losses = -critic_out[0], critic_out[1]
            actor_loss = actor_loss.mean() + compute_sum_aux_losses(actor_aux_losses)
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()
            self.soft_update(self.actor_target, self.actor, self.actor_tau)

            self.num_actor_updates += 1
            output = {
                "loss/critic_1": {
                    "type": "scalar",
                    "data": loss_Q1.data.cpu().numpy(),
                    "freq": 10,
                },
                "loss/actor": {
                    "type": "scalar",
                    "data": actor_loss.data.cpu().numpy(),
                    "freq": 10,
                },
            }
        self.soft_update(self.critic_1_target, self.critic_1, self.critic_tau)
        self.soft_update(self.critic_2_target, self.critic_2, self.critic_tau)
        self.current_iteration += 1
        return output

    def soft_update(self, target, src, tau):
        # Polyak averaging: target <- (1 - tau) * target + tau * src
        for target_param, param in zip(target.parameters(), src.parameters()):
            target_param.detach_()
            target_param.copy_(target_param * (1.0 - tau) + param * tau)

    def load(self, model_dir):
        model_dir = pathlib.Path(model_dir)
        map_location = None
        if self.device and self.device.type == "cpu":
            map_location = "cpu"
        self.actor.load_state_dict(
            torch.load(model_dir / "actor.pth", map_location=map_location)
        )
        self.actor_target.load_state_dict(
            torch.load(model_dir / "actor_target.pth", map_location=map_location)
        )
        self.critic_1.load_state_dict(
            torch.load(model_dir / "critic_1.pth", map_location=map_location)
        )
        self.critic_1_target.load_state_dict(
            torch.load(model_dir / "critic_1_target.pth", map_location=map_location)
        )
        self.critic_2.load_state_dict(
            torch.load(model_dir / "critic_2.pth", map_location=map_location)
        )
        self.critic_2_target.load_state_dict(
            torch.load(model_dir / "critic_2_target.pth", map_location=map_location)
        )

    def save(self, model_dir):
        model_dir = pathlib.Path(model_dir)
        torch.save(self.actor.state_dict(), model_dir / "actor.pth")
        torch.save(
            self.actor_target.state_dict(),
            model_dir / "actor_target.pth",
        )
        torch.save(self.critic_1.state_dict(), model_dir / "critic_1.pth")
        torch.save(
            self.critic_1_target.state_dict(),
            model_dir / "critic_1_target.pth",
        )
        torch.save(self.critic_2.state_dict(), model_dir / "critic_2.pth")
        torch.save(
            self.critic_2_target.state_dict(),
            model_dir / "critic_2_target.pth",
        )