Example #1
    def __call__(
        self,
        model: Brain,
        target_model: Brain,
        experiences: Tuple[torch.Tensor, ...],
        gamma: float,
        head_cfg: ConfigDict,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return element-wise C51 loss and Q-values."""
        states, actions, rewards, next_states, dones = experiences[:5]
        batch_size = states.shape[0]

        support = torch.linspace(
            head_cfg.configs.v_min, head_cfg.configs.v_max, head_cfg.configs.atom_size
        ).to(device)
        delta_z = float(head_cfg.configs.v_max - head_cfg.configs.v_min) / (
            head_cfg.configs.atom_size - 1
        )

        with torch.no_grad():
            # According to the NoisyNet paper, the noise parameters of the online
            # network should be resampled when using double Q, but we skip this
            # since it makes no noticeable difference in performance.
            next_actions = model.forward_(next_states)[1].argmax(1)

            next_dist = target_model.forward_(next_states)[0]
            next_dist = next_dist[range(batch_size), next_actions]

            t_z = rewards + (1 - dones) * gamma * support
            t_z = t_z.clamp(min=head_cfg.configs.v_min, max=head_cfg.configs.v_max)
            b = (t_z - head_cfg.configs.v_min) / delta_z
            l = b.floor().long()  # noqa: E741
            u = b.ceil().long()

            offset = (
                torch.linspace(
                    0, (batch_size - 1) * head_cfg.configs.atom_size, batch_size
                )
                .long()
                .unsqueeze(1)
                .expand(batch_size, head_cfg.configs.atom_size)
                .to(device)
            )

            proj_dist = torch.zeros(next_dist.size(), device=device)
            proj_dist.view(-1).index_add_(
                0, (l + offset).view(-1), (next_dist * (u.float() - b)).view(-1)
            )
            proj_dist.view(-1).index_add_(
                0, (u + offset).view(-1), (next_dist * (b - l.float())).view(-1)
            )

        dist, q_values = model.forward_(states)
        log_p = torch.log(
            torch.clamp(dist[range(batch_size), actions.long()], min=1e-7)
        )

        dq_loss_element_wise = -(proj_dist * log_p).sum(1, keepdim=True)

        return dq_loss_element_wise, q_values
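The projection step above is the hardest part of this loss to read in isolation. Below is a small self-contained sketch with made-up support bounds and a uniform toy distribution that runs the same index_add_ projection on its own; it is only an illustration of the technique, not part of the library.

import torch

# Toy settings (illustrative only, not the library's defaults).
v_min, v_max, atom_size, gamma = -10.0, 10.0, 51, 0.9
support = torch.linspace(v_min, v_max, atom_size)
delta_z = (v_max - v_min) / (atom_size - 1)

batch_size = 2
rewards = torch.tensor([[0.05], [0.15]])
dones = torch.zeros(batch_size, 1)
next_dist = torch.full((batch_size, atom_size), 1.0 / atom_size)  # uniform toy distribution

# Bellman-update every atom, then distribute its mass onto the two nearest atoms.
t_z = (rewards + (1 - dones) * gamma * support).clamp(v_min, v_max)
b = (t_z - v_min) / delta_z
l, u = b.floor().long(), b.ceil().long()

offset = (torch.arange(batch_size) * atom_size).unsqueeze(1).expand(batch_size, atom_size)
proj_dist = torch.zeros_like(next_dist)
proj_dist.view(-1).index_add_(0, (l + offset).view(-1), (next_dist * (u.float() - b)).view(-1))
proj_dist.view(-1).index_add_(0, (u + offset).view(-1), (next_dist * (b - l.float())).view(-1))

print(proj_dist.sum(dim=1))  # each row still sums to ~1.0 after the projection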
Example #2
    def _init_network(self):
        """Initialize networks and optimizers."""
        # create actor
        self.actor = Brain(self.backbone_cfg.actor,
                           self.head_cfg.actor).to(self.device)
        self.actor_target = Brain(self.backbone_cfg.actor,
                                  self.head_cfg.actor).to(self.device)
        self.actor_target.load_state_dict(self.actor.state_dict())

        # create critic
        self.critic = Brain(self.backbone_cfg.critic,
                            self.head_cfg.critic).to(self.device)
        self.critic_target = Brain(self.backbone_cfg.critic,
                                   self.head_cfg.critic).to(self.device)
        self.critic_target.load_state_dict(self.critic.state_dict())

        # create optimizer
        self.actor_optim = optim.Adam(
            self.actor.parameters(),
            lr=self.optim_cfg.lr_actor,
            weight_decay=self.optim_cfg.weight_decay,
        )

        self.critic_optim = optim.Adam(
            self.critic.parameters(),
            lr=self.optim_cfg.lr_critic,
            weight_decay=self.optim_cfg.weight_decay,
        )

        # load the optimizer and model parameters
        if self.args.load_from is not None:
            self.load_params(self.args.load_from)
Example #3
def test_brain():
    """Test whether Brain makes the fc layer based on the backbone's output size."""

    head_cfg.configs.state_size = test_state_dim
    head_cfg.configs.output_size = 8

    model = Brain(resnet_cfg, head_cfg)
    assert model.head.input_size == 16384
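The assertion above implies that Brain infers its head's input size from the backbone rather than reading it from the config. A minimal sketch of that idea, assuming a dummy forward pass is used to measure the flattened feature size (TinyBrain and its shapes are hypothetical, not the library's implementation):

import torch
import torch.nn as nn

class TinyBrain(nn.Module):
    """Hypothetical Brain-like wrapper that sizes its fc head from the backbone."""

    def __init__(self, backbone: nn.Module, state_shape, output_size: int):
        super().__init__()
        self.backbone = backbone
        with torch.no_grad():  # dummy forward pass to measure the flattened feature size
            feature_size = backbone(torch.zeros(1, *state_shape)).reshape(1, -1).shape[1]
        self.head = nn.Linear(feature_size, output_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(self.backbone(x).reshape(x.shape[0], -1))

backbone = nn.Sequential(nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1), nn.ReLU())
model = TinyBrain(backbone, state_shape=(3, 64, 64), output_size=8)
print(model.head.in_features)  # 16 * 32 * 32 = 16384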
Example #4
    def _init_network(self):
        """Initialize networks and optimizers."""
        # create actor
        self.actor = Brain(self.backbone_cfg.actor,
                           self.head_cfg.actor).to(self.device)

        # create v_critic
        self.vf = Brain(self.backbone_cfg.critic_vf,
                        self.head_cfg.critic_vf).to(self.device)
        self.vf_target = Brain(self.backbone_cfg.critic_vf,
                               self.head_cfg.critic_vf).to(self.device)
        self.vf_target.load_state_dict(self.vf.state_dict())

        # create q_critic
        self.qf_1 = Brain(self.backbone_cfg.critic_qf,
                          self.head_cfg.critic_qf).to(self.device)
        self.qf_2 = Brain(self.backbone_cfg.critic_qf,
                          self.head_cfg.critic_qf).to(self.device)

        # create optimizers
        self.actor_optim = optim.Adam(
            self.actor.parameters(),
            lr=self.optim_cfg.lr_actor,
            weight_decay=self.optim_cfg.weight_decay,
        )
        self.vf_optim = optim.Adam(
            self.vf.parameters(),
            lr=self.optim_cfg.lr_vf,
            weight_decay=self.optim_cfg.weight_decay,
        )
        self.qf_1_optim = optim.Adam(
            self.qf_1.parameters(),
            lr=self.optim_cfg.lr_qf1,
            weight_decay=self.optim_cfg.weight_decay,
        )
        self.qf_2_optim = optim.Adam(
            self.qf_2.parameters(),
            lr=self.optim_cfg.lr_qf2,
            weight_decay=self.optim_cfg.weight_decay,
        )

        # load the optimizer and model parameters
        if self.args.load_from is not None:
            self.load_params(self.args.load_from)
Example #5
    def _synchronize(self, network: Brain,
                     new_state_dict: Dict[str, np.ndarray]):
        """Copy parameters from numpy arrays."""
        param_name_list = list(new_state_dict.keys())
        for worker_named_param in network.named_parameters():
            worker_param_name = worker_named_param[0]
            if worker_param_name in param_name_list:
                new_param = numpy2floattensor(
                    new_state_dict[worker_param_name], self.device)
                worker_named_param[1].data.copy_(new_param)
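numpy2floattensor is used throughout these snippets but never shown; it presumably converts a numpy array to a float32 tensor on the given device, roughly like the sketch below (an assumption, not the library's code).

import numpy as np
import torch

def numpy2floattensor(array: np.ndarray, device: torch.device) -> torch.Tensor:
    """Assumed behaviour: numpy array -> float32 tensor on the given device."""
    return torch.from_numpy(np.asarray(array)).float().to(device)

device = torch.device("cpu")
new_param = numpy2floattensor(np.ones((4, 4)), device)
print(new_param.dtype, new_param.device)  # torch.float32 cpu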
Example #6
    def __init__(
        self,
        args: argparse.Namespace,
        env_info: ConfigDict,
        log_cfg: ConfigDict,
        comm_cfg: ConfigDict,
        backbone: ConfigDict,
        head: ConfigDict,
    ):
        self.args = args
        self.env_info = env_info
        self.log_cfg = log_cfg
        self.comm_cfg = comm_cfg
        self.device = torch.device("cpu")  # Logger only runs on cpu
        self.brain = Brain(backbone, head).to(self.device)

        self.update_step = 0
        self.log_info_queue = deque(maxlen=100)

        self._init_env()
Example #7
    def _init_network(self):
        """Initialize networks and optimizers."""
        self.dqn = Brain(self.backbone_cfg, self.head_cfg).to(self.device)
        self.dqn_target = Brain(self.backbone_cfg,
                                self.head_cfg).to(self.device)
        self.loss_fn = build_loss(self.loss_type)

        self.dqn_target.load_state_dict(self.dqn.state_dict())

        # create optimizer
        self.dqn_optim = optim.Adam(
            self.dqn.parameters(),
            lr=self.optim_cfg.lr_dqn,
            weight_decay=self.optim_cfg.weight_decay,
            eps=self.optim_cfg.adam_eps,
        )

        # load the optimizer and model parameters
        if self.args.load_from is not None:
            self.load_params(self.args.load_from)
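build_loss is not shown in these examples; it is presumably a small factory that looks up a loss callable (such as the C51 loss in Example #1 or the IQN loss in Example #18) in a registry keyed by the config. One plausible sketch, where the loss_cfg.type key is an assumption:

from typing import Callable, Dict

LOSS_REGISTRY: Dict[str, Callable] = {}

def register_loss(name: str):
    """Decorator that records a loss class under a string key."""
    def _wrap(cls):
        LOSS_REGISTRY[name] = cls
        return cls
    return _wrap

def build_loss(loss_cfg) -> Callable:
    """Instantiate the loss whose name appears in the config (sketch only)."""
    return LOSS_REGISTRY[loss_cfg.type]()

@register_loss("C51Loss")
class C51Loss:
    def __call__(self, model, target_model, experiences, gamma, head_cfg):
        raise NotImplementedError  # element-wise loss as in Example #1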
Example #8
    def _init_network(self):
        """Initialize networks and optimizers."""
        # create actor
        if self.backbone_cfg.shared_actor_critic:
            shared_backbone = build_backbone(
                self.backbone_cfg.shared_actor_critic)
            self.actor = Brain(
                self.backbone_cfg.shared_actor_critic,
                self.head_cfg.actor,
                shared_backbone,
            )
            self.critic = Brain(
                self.backbone_cfg.shared_actor_critic,
                self.head_cfg.critic,
                shared_backbone,
            )
            self.actor = self.actor.to(self.device)
            self.critic = self.critic.to(self.device)
        else:
            self.actor = Brain(self.backbone_cfg.actor,
                               self.head_cfg.actor).to(self.device)
            self.critic = Brain(self.backbone_cfg.critic,
                                self.head_cfg.critic).to(self.device)
        self.discriminator = Discriminator(
            self.backbone_cfg.discriminator,
            self.head_cfg.discriminator,
            self.head_cfg.aciton_embedder,
        ).to(self.device)

        # create optimizer
        self.actor_optim = optim.Adam(
            self.actor.parameters(),
            lr=self.optim_cfg.lr_actor,
            weight_decay=self.optim_cfg.weight_decay,
        )

        self.critic_optim = optim.Adam(
            self.critic.parameters(),
            lr=self.optim_cfg.lr_critic,
            weight_decay=self.optim_cfg.weight_decay,
        )

        self.discriminator_optim = optim.Adam(
            self.discriminator.parameters(),
            lr=self.optim_cfg.lr_discriminator,
            weight_decay=self.optim_cfg.weight_decay,
        )

        # load model parameters
        if self.load_from is not None:
            self.load_params(self.load_from)
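When shared_actor_critic is set, the same backbone instance is passed to both Brain constructors, so actor and critic share one trunk with separate heads. The pattern, stripped down to plain nn.Module pieces (illustrative names, not the library's API):

import torch
import torch.nn as nn

shared_backbone = nn.Sequential(nn.Linear(8, 64), nn.ReLU())  # built once, reused twice
actor_head = nn.Linear(64, 2)
critic_head = nn.Linear(64, 1)

state = torch.randn(4, 8)
features = shared_backbone(state)        # both heads read (and backprop through) the same trunk
action_logits = actor_head(features)
value = critic_head(features)
print(action_logits.shape, value.shape)  # torch.Size([4, 2]) torch.Size([4, 1])

Note that if both Brains register the shared backbone as a submodule, its parameters will likely appear in both actor.parameters() and critic.parameters() above, so both optimizers can move the trunk.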
Example #9
    def _init_network(self):
        """Initialize networks and optimizers."""
        self.actor = Brain(self.backbone_cfg.actor, self.head_cfg.actor).to(self.device)
        self.critic = Brain(self.backbone_cfg.critic, self.head_cfg.critic).to(
            self.device
        )

        # create optimizer
        self.actor_optim = optim.Adam(
            self.actor.parameters(),
            lr=self.optim_cfg.lr_actor,
            weight_decay=self.optim_cfg.weight_decay,
        )

        self.critic_optim = optim.Adam(
            self.critic.parameters(),
            lr=self.optim_cfg.lr_critic,
            weight_decay=self.optim_cfg.weight_decay,
        )

        if self.load_from is not None:
            self.load_params(self.load_from)
Example #10
    def __init__(
        self,
        log_cfg: ConfigDict,
        comm_cfg: ConfigDict,
        backbone: ConfigDict,
        head: ConfigDict,
        env_name: str,
        is_atari: bool,
        state_size: int,
        output_size: int,
        max_update_step: int,
        episode_num: int,
        max_episode_steps: int,
        interim_test_num: int,
        is_log: bool,
        is_render: bool,
    ):
        self.log_cfg = log_cfg
        self.comm_cfg = comm_cfg
        self.device = torch.device("cpu")  # Logger only runs on cpu
        head.configs.state_size = state_size
        head.configs.output_size = output_size
        self.brain = Brain(backbone, head).to(self.device)

        self.env_name = env_name
        self.is_atari = is_atari
        self.max_update_step = max_update_step
        self.episode_num = episode_num
        self.max_episode_steps = max_episode_steps
        self.interim_test_num = interim_test_num
        self.is_log = is_log
        self.is_render = is_render

        self.update_step = 0
        self.log_info_queue = deque(maxlen=100)

        self._init_env()
Example #11
    def _init_network(self):
        """Initialize network and optimizer."""
        self.actor = Brain(self.backbone_cfg.actor, self.head_cfg.actor).to(self.device)
        self.critic = Brain(self.backbone_cfg.critic, self.head_cfg.critic).to(
            self.device
        )
        # create optimizer
        self.actor_optim = optim.Adam(
            self.actor.parameters(), lr=self.optim_cfg.lr, eps=self.optim_cfg.adam_eps
        )
        self.critic_optim = optim.Adam(
            self.critic.parameters(), lr=self.optim_cfg.lr, eps=self.optim_cfg.adam_eps
        )

        self.actor_target = Brain(self.backbone_cfg.actor, self.head_cfg.actor).to(
            self.device
        )
        self.actor_target.load_state_dict(self.actor.state_dict())

        if self.load_from is not None:
            self.load_params(self.load_from)
Example #12
class DQNWorker(DistributedWorker):
    """DQN worker for distributed training.

    Attributes:
        backbone (ConfigDict): backbone configs for building network
        head (ConfigDict): head configs for building network
        state_dict (ConfigDict): initial network state dict received from learner
        device (str): literal to indicate cpu/cuda use

    """
    def __init__(
        self,
        rank: int,
        args: argparse.Namespace,
        env_info: ConfigDict,
        hyper_params: ConfigDict,
        backbone: ConfigDict,
        head: ConfigDict,
        state_dict: OrderedDict,
        device: str,
        loss_type: ConfigDict,
    ):
        DistributedWorker.__init__(self, rank, args, env_info, hyper_params,
                                   device)
        self.loss_fn = build_loss(loss_type)
        self.backbone_cfg = backbone
        self.head_cfg = head
        self.head_cfg.configs.state_size = self.env_info.observation_space.shape
        self.head_cfg.configs.output_size = self.env_info.action_space.n

        self.use_n_step = self.hyper_params.n_step > 1

        self.max_epsilon = self.hyper_params.max_epsilon
        self.min_epsilon = self.hyper_params.min_epsilon
        self.epsilon = self.hyper_params.max_epsilon

        self._init_networks(state_dict)

    # pylint: disable=attribute-defined-outside-init
    def _init_networks(self, state_dict: OrderedDict):
        """Initialize DQN policy with learner state dict."""
        self.dqn = Brain(self.backbone_cfg, self.head_cfg).to(self.device)
        self.dqn.load_state_dict(state_dict)
        self.dqn.eval()

    def load_params(self, path: str):
        """Load model and optimizer parameters."""
        DistributedWorker.load_params(self, path)

        params = torch.load(path)
        self.dqn.load_state_dict(params["dqn_state_dict"])
        print("[INFO] loaded the model and optimizer from", path)

    def select_action(self, state: np.ndarray) -> np.ndarray:
        """Select an action from the input space."""
        # epsilon greedy policy
        # pylint: disable=comparison-with-callable
        if self.epsilon > np.random.random():
            selected_action = np.array(self.env.action_space.sample())
        else:
            with torch.no_grad():
                state = self._preprocess_state(state, self.device)
                selected_action = self.dqn(state).argmax()
            selected_action = selected_action.cpu().numpy()

        # Decay epsilon
        self.epsilon = max(
            self.epsilon - (self.max_epsilon - self.min_epsilon) *
            self.hyper_params.epsilon_decay,
            self.min_epsilon,
        )

        return selected_action

    def step(self,
             action: np.ndarray) -> Tuple[np.ndarray, np.float64, bool, dict]:
        """Take an action and return the response of the env."""
        next_state, reward, done, info = self.env.step(action)
        return next_state, reward, done, info

    def compute_priorities(self, memory: Dict[str, np.ndarray]) -> np.ndarray:
        """Compute initial priority values of experiences in local memory."""
        states = numpy2floattensor(memory["states"], self.device)
        actions = numpy2floattensor(memory["actions"], self.device).long()
        rewards = numpy2floattensor(memory["rewards"].reshape(-1, 1),
                                    self.device)
        next_states = numpy2floattensor(memory["next_states"], self.device)
        dones = numpy2floattensor(memory["dones"].reshape(-1, 1), self.device)
        memory_tensors = (states, actions, rewards, next_states, dones)

        with torch.no_grad():
            dq_loss_element_wise, _ = self.loss_fn(
                self.dqn,
                self.dqn,
                memory_tensors,
                self.hyper_params.gamma,
                self.head_cfg,
            )
        loss_for_prior = dq_loss_element_wise.detach().cpu().numpy()
        new_priorities = loss_for_prior + self.hyper_params.per_eps
        return new_priorities

    def synchronize(self, new_state_dict: Dict[str, np.ndarray]):
        """Synchronize worker dqn with learner dqn."""
        self._synchronize(self.dqn, new_state_dict)
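select_action above decays epsilon linearly by (max_epsilon - min_epsilon) * epsilon_decay per step and clamps at min_epsilon. A quick numeric check with made-up hyper-parameters (the values below are not the library's defaults):

max_epsilon, min_epsilon, epsilon_decay = 1.0, 0.01, 1e-4  # illustrative values

epsilon = max_epsilon
for step in range(1, 20001):
    epsilon = max(epsilon - (max_epsilon - min_epsilon) * epsilon_decay, min_epsilon)
    if step in (1, 5000, 10000, 20000):
        print(step, round(epsilon, 4))
# 1 0.9999 / 5000 0.505 / 10000 0.01 / 20000 0.01 -- fully decayed after ~10k steps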
Example #13
class DDPGLearner(Learner):
    """Learner for DDPG Agent.

    Attributes:
        args (argparse.Namespace): arguments including hyperparameters and training settings
        hyper_params (ConfigDict): hyper-parameters
        optim_cfg (ConfigDict): config of optimizer
        log_cfg (ConfigDict): configuration for saving log and checkpoint
        actor (nn.Module): actor model to select actions
        actor_target (nn.Module): target actor model to select actions
        critic (nn.Module): critic model to predict state values
        critic_target (nn.Module): target critic model to predict state values
        actor_optim (Optimizer): optimizer for training actor
        critic_optim (Optimizer): optimizer for training critic

    """
    def __init__(
        self,
        args: argparse.Namespace,
        env_info: ConfigDict,
        hyper_params: ConfigDict,
        log_cfg: ConfigDict,
        backbone: ConfigDict,
        head: ConfigDict,
        optim_cfg: ConfigDict,
        noise_cfg: ConfigDict,
        device: torch.device,
    ):
        Learner.__init__(self, args, env_info, hyper_params, log_cfg, device)

        self.backbone_cfg = backbone
        self.head_cfg = head
        self.head_cfg.critic.configs.state_size = (
            self.env_info.observation_space.shape[0] +
            self.env_info.action_space.shape[0], )
        self.head_cfg.actor.configs.state_size = self.env_info.observation_space.shape
        self.head_cfg.actor.configs.output_size = self.env_info.action_space.shape[
            0]
        self.optim_cfg = optim_cfg
        self.noise_cfg = noise_cfg

        self._init_network()

    def _init_network(self):
        """Initialize networks and optimizers."""
        # create actor
        self.actor = Brain(self.backbone_cfg.actor,
                           self.head_cfg.actor).to(self.device)
        self.actor_target = Brain(self.backbone_cfg.actor,
                                  self.head_cfg.actor).to(self.device)
        self.actor_target.load_state_dict(self.actor.state_dict())

        # create critic
        self.critic = Brain(self.backbone_cfg.critic,
                            self.head_cfg.critic).to(self.device)
        self.critic_target = Brain(self.backbone_cfg.critic,
                                   self.head_cfg.critic).to(self.device)
        self.critic_target.load_state_dict(self.critic.state_dict())

        # create optimizer
        self.actor_optim = optim.Adam(
            self.actor.parameters(),
            lr=self.optim_cfg.lr_actor,
            weight_decay=self.optim_cfg.weight_decay,
        )

        self.critic_optim = optim.Adam(
            self.critic.parameters(),
            lr=self.optim_cfg.lr_critic,
            weight_decay=self.optim_cfg.weight_decay,
        )

        # load the optimizer and model parameters
        if self.args.load_from is not None:
            self.load_params(self.args.load_from)

    def update_model(
            self, experience: Tuple[torch.Tensor,
                                    ...]) -> Tuple[torch.Tensor, ...]:
        """Update actor and critic networks."""
        states, actions, rewards, next_states, dones = experience

        # G_t   = r + gamma * v(s_{t+1})  if state != Terminal
        #       = r                       otherwise
        masks = 1 - dones
        next_actions = self.actor_target(next_states)
        next_values = self.critic_target(
            torch.cat((next_states, next_actions), dim=-1))
        curr_returns = rewards + self.hyper_params.gamma * next_values * masks
        curr_returns = curr_returns.to(self.device)

        # train critic
        gradient_clip_ac = self.hyper_params.gradient_clip_ac
        gradient_clip_cr = self.hyper_params.gradient_clip_cr

        values = self.critic(torch.cat((states, actions), dim=-1))
        critic_loss = F.mse_loss(values, curr_returns)
        self.critic_optim.zero_grad()
        critic_loss.backward()
        clip_grad_norm_(self.critic.parameters(), gradient_clip_cr)
        self.critic_optim.step()

        # train actor
        actions = self.actor(states)
        actor_loss = -self.critic(torch.cat((states, actions), dim=-1)).mean()
        self.actor_optim.zero_grad()
        actor_loss.backward()
        clip_grad_norm_(self.actor.parameters(), gradient_clip_ac)
        self.actor_optim.step()

        # update target networks
        common_utils.soft_update(self.actor, self.actor_target,
                                 self.hyper_params.tau)
        common_utils.soft_update(self.critic, self.critic_target,
                                 self.hyper_params.tau)

        return actor_loss.item(), critic_loss.item()

    def save_params(self, n_episode: int):
        """Save model and optimizer parameters."""
        params = {
            "actor_state_dict": self.actor.state_dict(),
            "actor_target_state_dict": self.actor_target.state_dict(),
            "critic_state_dict": self.critic.state_dict(),
            "critic_target_state_dict": self.critic_target.state_dict(),
            "actor_optim_state_dict": self.actor_optim.state_dict(),
            "critic_optim_state_dict": self.critic_optim.state_dict(),
        }
        Learner._save_params(self, params, n_episode)

    def load_params(self, path: str):
        """Load model and optimizer parameters."""
        Learner.load_params(self, path)

        params = torch.load(path)
        self.actor.load_state_dict(params["actor_state_dict"])
        self.actor_target.load_state_dict(params["actor_target_state_dict"])
        self.critic.load_state_dict(params["critic_state_dict"])
        self.critic_target.load_state_dict(params["critic_target_state_dict"])
        self.actor_optim.load_state_dict(params["actor_optim_state_dict"])
        self.critic_optim.load_state_dict(params["critic_optim_state_dict"])
        print("[INFO] loaded the model and optimizer from", path)

    def get_state_dict(self) -> Tuple[OrderedDict]:
        """Return state dicts, mainly for distributed worker."""
        return (self.critic_target.state_dict(), self.actor.state_dict())

    def get_policy(self) -> nn.Module:
        """Return model (policy) used for action selection."""
        return self.actor
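common_utils.soft_update is used here and in the other learners but not shown; it is presumably the standard Polyak averaging update, target <- tau * local + (1 - tau) * target. A sketch under that assumption:

import torch
import torch.nn as nn

def soft_update(local_model: nn.Module, target_model: nn.Module, tau: float):
    """Blend local parameters into the target network in place (assumed behaviour)."""
    for t_param, l_param in zip(target_model.parameters(), local_model.parameters()):
        t_param.data.copy_(tau * l_param.data + (1.0 - tau) * t_param.data)

local, target = nn.Linear(2, 2), nn.Linear(2, 2)
target.load_state_dict(local.state_dict())
soft_update(local, target, tau=0.005)  # with a small tau the target tracks the local net slowly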
Example #14
class DistributedLogger(ABC):
    """Base class for loggers use in distributed training.

    Attributes:
        log_cfg (ConfigDict): configuration for saving log and checkpoint
        comm_cfg (ConfigDict): configs for communication
        backbone (ConfigDict): backbone configs for building network
        head (ConfigDict): head configs for building network
        brain (Brain): logger brain for evaluation
        update_step (int): tracker for learner update step
        device (torch.device): device, cpu by default
        log_info_queue (deque): queue for storing log info received from learner
        env (gym.Env): gym environment for running test

    """
    def __init__(
        self,
        log_cfg: ConfigDict,
        comm_cfg: ConfigDict,
        backbone: ConfigDict,
        head: ConfigDict,
        env_name: str,
        is_atari: bool,
        state_size: int,
        output_size: int,
        max_update_step: int,
        episode_num: int,
        max_episode_steps: int,
        interim_test_num: int,
        is_log: bool,
        is_render: bool,
    ):
        self.log_cfg = log_cfg
        self.comm_cfg = comm_cfg
        self.device = torch.device("cpu")  # Logger only runs on cpu
        head.configs.state_size = state_size
        head.configs.output_size = output_size
        self.brain = Brain(backbone, head).to(self.device)

        self.env_name = env_name
        self.is_atari = is_atari
        self.max_update_step = max_update_step
        self.episode_num = episode_num
        self.max_episode_steps = max_episode_steps
        self.interim_test_num = interim_test_num
        self.is_log = is_log
        self.is_render = is_render

        self.update_step = 0
        self.log_info_queue = deque(maxlen=100)

        self._init_env()

    # pylint: disable=attribute-defined-outside-init
    def _init_env(self):
        """Initialize gym environment."""
        if self.is_atari:
            self.env = atari_env_generator(self.env_name,
                                           self.max_episode_steps)
        else:
            self.env = gym.make(self.env_name)
            self.env, self.max_episode_steps = env_utils.set_env(
                self.env, self.max_episode_steps)

    @abstractmethod
    def load_params(self, path: str):
        if not os.path.exists(path):
            raise Exception(
                f"[ERROR] the input path does not exist. Wrong path: {path}")

    # pylint: disable=attribute-defined-outside-init
    def init_communication(self):
        """Initialize inter-process communication sockets."""
        ctx = zmq.Context()
        self.pull_socket = ctx.socket(zmq.PULL)
        self.pull_socket.bind(
            f"tcp://127.0.0.1:{self.comm_cfg.learner_logger_port}")

    @abstractmethod
    def select_action(self, state: np.ndarray):
        pass

    @abstractmethod
    def write_log(self, log_value: dict):
        pass

    # pylint: disable=no-self-use
    @staticmethod
    def _preprocess_state(state: np.ndarray,
                          device: torch.device) -> torch.Tensor:
        state = numpy2floattensor(state, device)
        return state

    def set_wandb(self):
        """Set configuration for wandb logging."""
        wandb.init(
            project=self.env_name,
            name=f"{self.log_cfg.agent}/{self.log_cfg.curr_time}",
        )
        additional_log = dict(
            episode_num=self.episode_num,
            max_episode_steps=self.max_episode_steps,
        )
        wandb.config.update(additional_log)
        shutil.copy(self.log_cfg.cfg_path,
                    os.path.join(wandb.run.dir, "config.yaml"))

    def recv_log_info(self):
        """Receive info from learner."""
        received = False
        try:
            log_info_id = self.pull_socket.recv(zmq.DONTWAIT)
            received = True
        except zmq.Again:
            pass

        if received:
            self.log_info_queue.append(log_info_id)

    def run(self):
        """Run main logging loop; continuously receive data and log."""
        if self.is_log:
            self.set_wandb()

        while self.update_step < self.max_update_step:
            self.recv_log_info()
            if self.log_info_queue:  # if non-empty
                log_info_id = self.log_info_queue.pop()
                log_info = pa.deserialize(log_info_id)
                state_dict = log_info["state_dict"]
                log_value = log_info["log_value"]
                self.update_step = log_value["update_step"]

                self.synchronize(state_dict)
                avg_score = self.test(self.update_step)
                log_value["avg_score"] = avg_score
                self.write_log(log_value)

    def write_worker_log(self, worker_logs: List[dict],
                         worker_update_interval: int):
        """Log the mean scores of each episode per update step to wandb."""
        # NOTE: Worker plots are passed to wandb.log as a plotly figure,
        #       since wandb doesn't support logging multiple lines to a single plot
        self.set_wandb()
        # Plot individual workers
        fig = go.Figure()
        worker_id = 0
        for worker_log in worker_logs:
            fig.add_trace(
                go.Scatter(
                    x=list(worker_log.keys()),
                    y=smoothen_graph(list(worker_log.values())),
                    mode="lines",
                    name=f"Worker {worker_id}",
                    line=dict(width=2),
                ))
            worker_id = worker_id + 1

        # Plot mean scores
        logged_update_steps = list(
            range(0, self.max_update_step + 1, worker_update_interval))

        mean_scores = []
        try:
            for step in logged_update_steps:
                scores_for_step = []
                for worker_log in worker_logs:
                    if step in list(worker_log):
                        scores_for_step.append(worker_log[step])
                mean_scores.append(np.mean(scores_for_step))
        except Exception as e:
            print(f"[Error] {e}")

        fig.add_trace(
            go.Scatter(
                x=logged_update_steps,
                y=mean_scores,
                mode="lines+markers",
                name="Mean scores",
                line=dict(width=5),
            ))

        # Write to wandb
        wandb.log({"Worker scores": fig})

    def test(self, update_step: int, interim_test: bool = True):
        """Test the agent."""
        avg_score = self._test(update_step, interim_test)

        # termination
        self.env.close()
        return avg_score

    def _test(self, update_step: int, interim_test: bool) -> float:
        """Common test routine."""
        if interim_test:
            test_num = self.interim_test_num
        else:
            test_num = self.episode_num

        self.brain.eval()
        scores = []
        for i_episode in range(test_num):
            state = self.env.reset()
            done = False
            score = 0
            step = 0

            while not done:
                if self.is_render:
                    self.env.render()

                action = self.select_action(state)
                next_state, reward, done, _ = self.env.step(action)

                state = next_state
                score += reward
                step += 1

            scores.append(score)

            if interim_test:
                print(
                    "[INFO] update step: %d\ttest %d\tstep: %d\ttotal score: %d"
                    % (update_step, i_episode, step, score))
            else:
                print("[INFO] test %d\tstep: %d\ttotal score: %d" %
                      (i_episode, step, score))

        return np.mean(scores)

    def synchronize(self, state_dict: Dict[str, np.ndarray]):
        """Copy parameters from numpy arrays."""
        param_name_list = list(state_dict.keys())
        for logger_named_param in self.brain.named_parameters():
            logger_param_name = logger_named_param[0]
            if logger_param_name in param_name_list:
                new_param = numpy2floattensor(state_dict[logger_param_name],
                                              self.device)
                logger_named_param[1].data.copy_(new_param)
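The learner side that feeds recv_log_info is not included in these examples. Given the PULL socket and the pa.deserialize call above, a plausible counterpart is a PUSH socket sending pyarrow-serialized dicts; the sketch below is an assumption about that producer, and the port literal is illustrative rather than the value from comm_cfg.

import pyarrow as pa
import zmq

ctx = zmq.Context()
push_socket = ctx.socket(zmq.PUSH)
push_socket.connect("tcp://127.0.0.1:6554")  # illustrative port, not the library's config value

log_info = {
    "state_dict": {},  # numpy parameter arrays in the real system
    "log_value": {"update_step": 100, "avg_loss": 0.5},
}
push_socket.send(pa.serialize(log_info).to_buffer())  # matches pa.deserialize in run() above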
Example #15
class ACERLearner(Learner):
    """Learner for ACER Agent.

    Attributes:
        args (argparse.Namespace): arguments including hyperparameters and training settings
        hyper_params (ConfigDict): hyper-parameters
        log_cfg (ConfigDict): configuration for saving log and checkpoint
        model (nn.Module): model to select actions and predict values
        model_optim (Optimizer): optimizer for training model

    """

    def __init__(
        self,
        backbone: ConfigDict,
        head: ConfigDict,
        optim_cfg: ConfigDict,
        trust_region: ConfigDict,
        hyper_params: ConfigDict,
        log_cfg: ConfigDict,
        env_info: ConfigDict,
        is_test: bool,
        load_from: str,
    ):
        Learner.__init__(self, hyper_params, log_cfg, env_info.name, is_test)

        self.backbone_cfg = backbone
        self.head_cfg = head
        self.load_from = load_from
        self.head_cfg.actor.configs.state_size = env_info.observation_space.shape
        self.head_cfg.critic.configs.state_size = env_info.observation_space.shape
        self.head_cfg.actor.configs.output_size = env_info.action_space.n
        self.head_cfg.critic.configs.output_size = env_info.action_space.n
        self.optim_cfg = optim_cfg
        self.gradient_clip = hyper_params.gradient_clip
        self.trust_region = trust_region

        self._init_network()

    def _init_network(self):
        """Initialize network and optimizer."""
        self.actor = Brain(self.backbone_cfg.actor, self.head_cfg.actor).to(self.device)
        self.critic = Brain(self.backbone_cfg.critic, self.head_cfg.critic).to(
            self.device
        )
        # create optimizer
        self.actor_optim = optim.Adam(
            self.actor.parameters(), lr=self.optim_cfg.lr, eps=self.optim_cfg.adam_eps
        )
        self.critic_optim = optim.Adam(
            self.critic.parameters(), lr=self.optim_cfg.lr, eps=self.optim_cfg.adam_eps
        )

        self.actor_target = Brain(self.backbone_cfg.actor, self.head_cfg.actor).to(
            self.device
        )
        self.actor_target.load_state_dict(self.actor.state_dict())

        if self.load_from is not None:
            self.load_params(self.load_from)

    def update_model(self, experience: Tuple) -> torch.Tensor:

        state, action, reward, prob, done = experience
        state = state.to(self.device)
        reward = reward.to(self.device)
        action = action.to(self.device)
        prob = prob.to(self.device).squeeze()
        done = done.to(self.device)

        pi = F.softmax(self.actor(state), 1)

        q = self.critic(state)
        q_i = q.gather(1, action)
        pi_i = pi.gather(1, action)

        with torch.no_grad():
            v = (q * pi).sum(1).unsqueeze(1)
            rho = pi / (prob + 1e-8)
        rho_i = rho.gather(1, action)
        rho_bar = rho_i.clamp(max=self.hyper_params.c)

        q_ret = self.q_retrace(
            reward, done, q_i, v, rho_bar, self.hyper_params.gamma
        ).to(self.device)

        loss_f = -rho_bar * torch.log(pi_i + 1e-8) * (q_ret - v)
        loss_bc = (
            -(1 - (self.hyper_params.c / rho)).clamp(min=0)
            * pi.detach()
            * torch.log(pi + 1e-8)
            * (q.detach() - v)
        )

        value_loss = torch.sqrt((q_i - q_ret).pow(2)).mean() * 0.5

        if self.trust_region.use_trust_region:
            g = loss_f + loss_bc
            pi_target = F.softmax(self.actor_target(state), 1)
            # gradient of KL(P || Q) with respect to Q is -P / Q
            k = -pi_target / (pi + 1e-8)
            k_dot_g = k * g
            tr = (
                g
                - ((k_dot_g - self.trust_region.delta) / torch.norm(k)).clamp(max=0) * k
            )
            loss = tr.mean() + value_loss
        else:
            loss = loss_f.mean() + loss_bc.sum(1).mean() + value_loss

        self.actor_optim.zero_grad()
        self.critic_optim.zero_grad()
        loss.backward()

        nn.utils.clip_grad_norm_(self.actor.parameters(), self.gradient_clip)
        nn.utils.clip_grad_norm_(self.critic.parameters(), self.gradient_clip)
        for name, param in self.actor.named_parameters():
            if not torch.isfinite(param.grad).all():
                print(name, torch.isfinite(param.grad).all())
                print("Warning : Gradient is infinite. Do not update gradient.")
                return loss
        for name, param in self.critic.named_parameters():
            if not torch.isfinite(param.grad).all():
                print(name, torch.isfinite(param.grad).all())
                print("Warning : Gradient is infinite. Do not update gradient.")
                return loss
        self.actor_optim.step()
        self.critic_optim.step()

        common_utils.soft_update(self.actor, self.actor_target, self.hyper_params.tau)

        return loss

    @staticmethod
    def q_retrace(
        reward: torch.Tensor,
        done: torch.Tensor,
        q_a: torch.Tensor,
        v: torch.Tensor,
        rho_bar: torch.Tensor,
        gamma: float,
    ):
        """Calculate Q retrace."""
        q_ret = v[-1]
        q_ret_lst = []

        for i in reversed(range(len(reward))):
            q_ret = reward[i] + gamma * q_ret * done[i]
            q_ret_lst.append(q_ret.item())
            q_ret = rho_bar[i] * (q_ret - q_a[i]) + v[i]

        q_ret_lst.reverse()
        q_ret = torch.FloatTensor(q_ret_lst).unsqueeze(1)
        return q_ret

    def save_params(self, n_episode: int):
        params = {
            "actor_state_dict": self.actor.state_dict(),
            "actor_optim_state_dict": self.actor_optim.state_dict(),
            "critic_state_dict": self.critic.state_dict(),
            "critic_optim_state_dict": self.critic_optim.state_dict(),
        }
        Learner._save_params(self, params, n_episode)

    def load_params(self, path: str):
        Learner.load_params(self, path)

        params = torch.load(path)
        self.actor.load_state_dict(params["actor_state_dict"])
        self.critic.load_state_dict(params["critic_state_dict"])
        self.actor_optim.load_state_dict(params["actor_optim_state_dict"])
        self.critic_optim.load_state_dict(params["critic_optim_state_dict"])
        print("[INFO] Loaded the model and optimizer from", path)

    def get_state_dict(self) -> Tuple[OrderedDict]:
        """Return state dicts, mainly for distributed worker."""
        return (self.actor.state_dict(), self.critic.state_dict())

    def get_policy(self) -> nn.Module:
        """Return model (policy) used for action selection."""
        return self.actor
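q_retrace is a static method, so it can be exercised on its own with toy tensors once the ACERLearner class above is in scope; the call below is only a shape check with made-up numbers (in the recursion above, done[i] multiplies the bootstrapped return, so a value of 0 truncates it).

import torch

T = 3
reward = torch.tensor([[1.0], [0.0], [1.0]])
done = torch.tensor([[1.0], [1.0], [0.0]])  # 0 truncates the bootstrapped return in the recursion
q_a = torch.zeros(T, 1)
v = torch.zeros(T, 1)
rho_bar = torch.ones(T, 1)

q_ret = ACERLearner.q_retrace(reward, done, q_a, v, rho_bar, gamma=0.99)
print(q_ret.shape)  # torch.Size([3, 1])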
Example #16
class SACLearner(Learner):
    """Learner for SAC Agent.

    Attributes:
        args (argparse.Namespace): arguments including hyperparameters and training settings
        hyper_params (ConfigDict): hyper-parameters
        log_cfg (ConfigDict): configuration for saving log and checkpoint
        update_step (int): step number of updates
        target_entropy (int): desired entropy used for the inequality constraint
        log_alpha (torch.Tensor): weight for entropy
        alpha_optim (Optimizer): optimizer for alpha
        actor (nn.Module): actor model to select actions
        actor_optim (Optimizer): optimizer for training actor
        critic_1 (nn.Module): critic model to predict state values
        critic_2 (nn.Module): critic model to predict state values
        critic_target1 (nn.Module): target critic model to predict state values
        critic_target2 (nn.Module): target critic model to predict state values
        critic_optim1 (Optimizer): optimizer for training critic_1
        critic_optim2 (Optimizer): optimizer for training critic_2

    """
    def __init__(
        self,
        args: argparse.Namespace,
        env_info: ConfigDict,
        hyper_params: ConfigDict,
        log_cfg: ConfigDict,
        backbone: ConfigDict,
        head: ConfigDict,
        optim_cfg: ConfigDict,
    ):
        Learner.__init__(self, args, env_info, hyper_params, log_cfg)

        self.backbone_cfg = backbone
        self.head_cfg = head
        self.head_cfg.actor.configs.state_size = (
            self.head_cfg.critic_vf.configs.state_size
        ) = self.env_info.observation_space.shape
        self.head_cfg.critic_qf.configs.state_size = (
            self.env_info.observation_space.shape[0] +
            self.env_info.action_space.shape[0], )
        self.head_cfg.actor.configs.output_size = self.env_info.action_space.shape[
            0]
        self.optim_cfg = optim_cfg

        self.update_step = 0
        if self.hyper_params.auto_entropy_tuning:
            self.target_entropy = -np.prod(
                (self.env_info.action_space.shape[0], )).item()
            self.log_alpha = torch.zeros(1,
                                         requires_grad=True,
                                         device=self.device)
            self.alpha_optim = optim.Adam([self.log_alpha],
                                          lr=optim_cfg.lr_entropy)

        self._init_network()

    # pylint: disable=attribute-defined-outside-init
    def _init_network(self):
        """Initialize networks and optimizers."""
        # create actor
        self.actor = Brain(self.backbone_cfg.actor,
                           self.head_cfg.actor).to(self.device)

        # create v_critic
        self.vf = Brain(self.backbone_cfg.critic_vf,
                        self.head_cfg.critic_vf).to(self.device)
        self.vf_target = Brain(self.backbone_cfg.critic_vf,
                               self.head_cfg.critic_vf).to(self.device)
        self.vf_target.load_state_dict(self.vf.state_dict())

        # create q_critic
        self.qf_1 = Brain(self.backbone_cfg.critic_qf,
                          self.head_cfg.critic_qf).to(self.device)
        self.qf_2 = Brain(self.backbone_cfg.critic_qf,
                          self.head_cfg.critic_qf).to(self.device)

        # create optimizers
        self.actor_optim = optim.Adam(
            self.actor.parameters(),
            lr=self.optim_cfg.lr_actor,
            weight_decay=self.optim_cfg.weight_decay,
        )
        self.vf_optim = optim.Adam(
            self.vf.parameters(),
            lr=self.optim_cfg.lr_vf,
            weight_decay=self.optim_cfg.weight_decay,
        )
        self.qf_1_optim = optim.Adam(
            self.qf_1.parameters(),
            lr=self.optim_cfg.lr_qf1,
            weight_decay=self.optim_cfg.weight_decay,
        )
        self.qf_2_optim = optim.Adam(
            self.qf_2.parameters(),
            lr=self.optim_cfg.lr_qf2,
            weight_decay=self.optim_cfg.weight_decay,
        )

        # load the optimizer and model parameters
        if self.args.load_from is not None:
            self.load_params(self.args.load_from)

    def update_model(
        self, experience: Union[TensorTuple, Tuple[TensorTuple]]
    ) -> Tuple[torch.Tensor, torch.Tensor, list, np.ndarray]:  # type: ignore
        """Update actor and critic networks."""
        self.update_step += 1

        states, actions, rewards, next_states, dones = experience
        new_actions, log_prob, pre_tanh_value, mu, std = self.actor(states)

        # train alpha
        if self.hyper_params.auto_entropy_tuning:
            alpha_loss = (-self.log_alpha *
                          (log_prob + self.target_entropy).detach()).mean()

            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()

            alpha = self.log_alpha.exp()
        else:
            alpha_loss = torch.zeros(1)
            alpha = self.hyper_params.w_entropy

        # Q function loss
        masks = 1 - dones
        states_actions = torch.cat((states, actions), dim=-1)
        q_1_pred = self.qf_1(states_actions)
        q_2_pred = self.qf_2(states_actions)
        v_target = self.vf_target(next_states)
        q_target = rewards + self.hyper_params.gamma * v_target * masks
        qf_1_loss = F.mse_loss(q_1_pred, q_target.detach())
        qf_2_loss = F.mse_loss(q_2_pred, q_target.detach())

        # V function loss
        states_actions = torch.cat((states, new_actions), dim=-1)
        v_pred = self.vf(states)
        q_pred = torch.min(self.qf_1(states_actions),
                           self.qf_2(states_actions))
        v_target = q_pred - alpha * log_prob
        vf_loss = F.mse_loss(v_pred, v_target.detach())

        if self.update_step % self.hyper_params.policy_update_freq == 0:
            # actor loss
            advantage = q_pred - v_pred.detach()
            actor_loss = (alpha * log_prob - advantage).mean()

            # regularization
            mean_reg = self.hyper_params.w_mean_reg * mu.pow(2).mean()
            std_reg = self.hyper_params.w_std_reg * std.pow(2).mean()
            pre_activation_reg = self.hyper_params.w_pre_activation_reg * (
                pre_tanh_value.pow(2).sum(dim=-1).mean())
            actor_reg = mean_reg + std_reg + pre_activation_reg

            # actor loss + regularization
            actor_loss += actor_reg

            # train actor
            self.actor_optim.zero_grad()
            actor_loss.backward()
            self.actor_optim.step()

            # update target networks
            common_utils.soft_update(self.vf, self.vf_target,
                                     self.hyper_params.tau)
        else:
            actor_loss = torch.zeros(1)

        # train Q functions
        self.qf_1_optim.zero_grad()
        qf_1_loss.backward()
        self.qf_1_optim.step()

        self.qf_2_optim.zero_grad()
        qf_2_loss.backward()
        self.qf_2_optim.step()

        # train V function
        self.vf_optim.zero_grad()
        vf_loss.backward()
        self.vf_optim.step()

        return (
            actor_loss.item(),
            qf_1_loss.item(),
            qf_2_loss.item(),
            vf_loss.item(),
            alpha_loss.item(),
        )

    def save_params(self, n_episode: int):
        """Save model and optimizer parameters."""
        params = {
            "actor": self.actor.state_dict(),
            "qf_1": self.qf_1.state_dict(),
            "qf_2": self.qf_2.state_dict(),
            "vf": self.vf.state_dict(),
            "vf_target": self.vf_target.state_dict(),
            "actor_optim": self.actor_optim.state_dict(),
            "qf_1_optim": self.qf_1_optim.state_dict(),
            "qf_2_optim": self.qf_2_optim.state_dict(),
            "vf_optim": self.vf_optim.state_dict(),
        }

        if self.hyper_params.auto_entropy_tuning:
            params["alpha_optim"] = self.alpha_optim.state_dict()

        Learner._save_params(self, params, n_episode)

    def load_params(self, path: str):
        """Load model and optimizer parameters."""
        Learner.load_params(self, path)

        params = torch.load(path)
        self.actor.load_state_dict(params["actor"])
        self.qf_1.load_state_dict(params["qf_1"])
        self.qf_2.load_state_dict(params["qf_2"])
        self.vf.load_state_dict(params["vf"])
        self.vf_target.load_state_dict(params["vf_target"])
        self.actor_optim.load_state_dict(params["actor_optim"])
        self.qf_1_optim.load_state_dict(params["qf_1_optim"])
        self.qf_2_optim.load_state_dict(params["qf_2_optim"])
        self.vf_optim.load_state_dict(params["vf_optim"])

        if self.hyper_params.auto_entropy_tuning:
            self.alpha_optim.load_state_dict(params["alpha_optim"])

        print("[INFO] loaded the model and optimizer from", path)

    def get_state_dict(self) -> Tuple[OrderedDict]:
        """Return state dicts, mainly for distributed worker."""
        return (self.qf_1.state_dict(), self.qf_2.state_dict(),
                self.actor.state_dict())

    def get_policy(self) -> nn.Module:
        """Return model (policy) used for action selection."""
        return self.actor
Example #17
class TD3Learner(Learner):
    """Learner for DDPG Agent.

    Attributes:
        hyper_params (ConfigDict): hyper-parameters
        network_cfg (ConfigDict): config of network for training agent
        optim_cfg (ConfigDict): config of optimizer
        noise_cfg (ConfigDict): config of noise
        target_policy_noise (GaussianNoise): random noise for target values
        actor (nn.Module): actor model to select actions
        critic1 (nn.Module): critic model to predict state values
        critic2 (nn.Module): critic model to predict state values
        critic_target1 (nn.Module): target critic model to predict state values
        critic_target2 (nn.Module): target critic model to predict state values
        actor_target (nn.Module): target actor model to select actions
        critic_optim (Optimizer): optimizer for training critic
        actor_optim (Optimizer): optimizer for training actor

    """
    def __init__(
        self,
        hyper_params: ConfigDict,
        log_cfg: ConfigDict,
        backbone: ConfigDict,
        head: ConfigDict,
        optim_cfg: ConfigDict,
        noise_cfg: ConfigDict,
        env_name: str,
        state_size: tuple,
        output_size: int,
        is_test: bool,
        load_from: str,
    ):
        Learner.__init__(self, hyper_params, log_cfg, env_name, is_test)

        self.backbone_cfg = backbone
        self.head_cfg = head
        self.head_cfg.critic.configs.state_size = (state_size[0] +
                                                   output_size, )
        self.head_cfg.actor.configs.state_size = state_size
        self.head_cfg.actor.configs.output_size = output_size
        self.optim_cfg = optim_cfg
        self.noise_cfg = noise_cfg
        self.load_from = load_from

        self.target_policy_noise = GaussianNoise(
            self.head_cfg.actor.configs.output_size,
            self.noise_cfg.target_policy_noise,
            self.noise_cfg.target_policy_noise,
        )

        self.update_step = 0

        self._init_network()

    def _init_network(self):
        """Initialize networks and optimizers."""
        # create actor
        self.actor = Brain(self.backbone_cfg.actor,
                           self.head_cfg.actor).to(self.device)
        self.actor_target = Brain(self.backbone_cfg.actor,
                                  self.head_cfg.actor).to(self.device)
        self.actor_target.load_state_dict(self.actor.state_dict())

        # create critic
        self.critic1 = Brain(self.backbone_cfg.critic,
                             self.head_cfg.critic).to(self.device)
        self.critic2 = Brain(self.backbone_cfg.critic,
                             self.head_cfg.critic).to(self.device)

        self.critic_target1 = Brain(self.backbone_cfg.critic,
                                    self.head_cfg.critic).to(self.device)
        self.critic_target2 = Brain(self.backbone_cfg.critic,
                                    self.head_cfg.critic).to(self.device)

        self.critic_target1.load_state_dict(self.critic1.state_dict())
        self.critic_target2.load_state_dict(self.critic2.state_dict())

        # concat critic parameters to use one optim
        critic_parameters = list(self.critic1.parameters()) + list(
            self.critic2.parameters())

        # create optimizers
        self.actor_optim = optim.Adam(
            self.actor.parameters(),
            lr=self.optim_cfg.lr_actor,
            weight_decay=self.optim_cfg.weight_decay,
        )

        self.critic_optim = optim.Adam(
            critic_parameters,
            lr=self.optim_cfg.lr_critic,
            weight_decay=self.optim_cfg.weight_decay,
        )

        # load the optimizer and model parameters
        if self.load_from is not None:
            self.load_params(self.load_from)

    def update_model(
            self, experience: Tuple[torch.Tensor,
                                    ...]) -> Tuple[torch.Tensor, ...]:
        """Update TD3 actor and critic networks."""
        self.update_step += 1

        states, actions, rewards, next_states, dones = experience
        masks = 1 - dones

        # get actions with noise
        noise = common_utils.numpy2floattensor(
            self.target_policy_noise.sample(), self.device)
        clipped_noise = torch.clamp(
            noise,
            -self.noise_cfg.target_policy_noise_clip,
            self.noise_cfg.target_policy_noise_clip,
        )
        next_actions = (self.actor_target(next_states) + clipped_noise).clamp(
            -1.0, 1.0)

        # min (Q_1', Q_2')
        next_states_actions = torch.cat((next_states, next_actions), dim=-1)
        next_values1 = self.critic_target1(next_states_actions)
        next_values2 = self.critic_target2(next_states_actions)
        next_values = torch.min(next_values1, next_values2)

        # G_t   = r + gamma * v(s_{t+1})  if state != Terminal
        #       = r                       otherwise
        curr_returns = rewards + self.hyper_params.gamma * next_values * masks
        curr_returns = curr_returns.detach()

        # critic loss
        state_actions = torch.cat((states, actions), dim=-1)
        values1 = self.critic1(state_actions)
        values2 = self.critic2(state_actions)
        critic1_loss = F.mse_loss(values1, curr_returns)
        critic2_loss = F.mse_loss(values2, curr_returns)

        # train critic
        critic_loss = critic1_loss + critic2_loss
        self.critic_optim.zero_grad()
        critic_loss.backward()
        self.critic_optim.step()

        if self.update_step % self.hyper_params.policy_update_freq == 0:
            # policy loss
            actions = self.actor(states)
            state_actions = torch.cat((states, actions), dim=-1)
            actor_loss = -self.critic1(state_actions).mean()

            # train actor
            self.actor_optim.zero_grad()
            actor_loss.backward()
            self.actor_optim.step()

            # update target networks
            tau = self.hyper_params.tau
            common_utils.soft_update(self.critic1, self.critic_target1, tau)
            common_utils.soft_update(self.critic2, self.critic_target2, tau)
            common_utils.soft_update(self.actor, self.actor_target, tau)
        else:
            actor_loss = torch.zeros(1)

        return actor_loss.item(), critic1_loss.item(), critic2_loss.item()

    def save_params(self, n_episode: int):
        """Save model and optimizer parameters."""
        params = {
            "actor": self.actor.state_dict(),
            "actor_target": self.actor_target.state_dict(),
            "actor_optim": self.actor_optim.state_dict(),
            "critic1": self.critic1.state_dict(),
            "critic2": self.critic2.state_dict(),
            "critic_target1": self.critic_target1.state_dict(),
            "critic_target2": self.critic_target2.state_dict(),
            "critic_optim": self.critic_optim.state_dict(),
        }

        Learner._save_params(self, params, n_episode)

    def load_params(self, path: str):
        """Load model and optimizer parameters."""
        Learner.load_params(self, path)

        params = torch.load(path)
        self.critic1.load_state_dict(params["critic1"])
        self.critic2.load_state_dict(params["critic2"])
        self.critic_target1.load_state_dict(params["critic_target1"])
        self.critic_target2.load_state_dict(params["critic_target2"])
        self.critic_optim.load_state_dict(params["critic_optim"])
        self.actor.load_state_dict(params["actor"])
        self.actor_target.load_state_dict(params["actor_target"])
        self.actor_optim.load_state_dict(params["actor_optim"])
        print("[INFO] loaded the model and optimizer from", path)

    def get_state_dict(self) -> Tuple[OrderedDict]:
        """Return state dicts, mainly for distributed worker."""
        return (
            self.critic_target1.state_dict(),
            self.critic_target2.state_dict(),
            self.actor.state_dict(),
        )

    def get_policy(self) -> nn.Module:
        """Return model (policy) used for action selection."""
        return self.actor
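GaussianNoise is consumed above for target policy smoothing but its implementation is not shown. Given that the learner passes target_policy_noise for both sigma arguments, it is assumed to behave roughly like the sketch below (illustrative, not the library's code).

import numpy as np

class GaussianNoise:
    """Assumed behaviour: fixed-scale Gaussian noise over the action dimension."""

    def __init__(self, action_dim: int, min_sigma: float, max_sigma: float):
        self.action_dim = action_dim
        self.sigma = max_sigma  # the TD3 learner above passes the same value twice

    def sample(self) -> np.ndarray:
        return np.random.normal(0.0, self.sigma, size=self.action_dim)

noise = GaussianNoise(action_dim=2, min_sigma=0.2, max_sigma=0.2)
print(noise.sample().shape)  # (2,)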
Example #18
    def __call__(
        self,
        model: Brain,
        target_model: Brain,
        experiences: Tuple[torch.Tensor, ...],
        gamma: float,
        head_cfg: ConfigDict,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return element-wise IQN loss and Q-values.

        Reference: https://github.com/google/dopamine
        """
        states, actions, rewards, next_states, dones = experiences[:5]
        batch_size = states.shape[0]

        # size of rewards: (n_tau_prime_samples x batch_size) x 1.
        rewards = rewards.repeat(head_cfg.configs.n_tau_prime_samples, 1)

        # size of gamma_with_terminal: (n_tau_prime_samples x batch_size) x 1.
        masks = 1 - dones
        gamma_with_terminal = masks * gamma
        gamma_with_terminal = gamma_with_terminal.repeat(
            head_cfg.configs.n_tau_prime_samples, 1)

        # Get the indices of the maximum Q-value across the action dimension.
        # Shape of replay_next_qt_argmax: (n_tau_prime_samples x batch_size) x 1.
        next_actions = model(next_states).argmax(dim=1)  # double Q
        next_actions = next_actions[:, None]
        next_actions = next_actions.repeat(
            head_cfg.configs.n_tau_prime_samples, 1)

        # Shape of next_target_values: (n_tau_prime_samples x batch_size) x 1.
        target_quantile_values, _ = target_model.forward_(
            next_states, head_cfg.configs.n_tau_prime_samples)
        target_quantile_values = target_quantile_values.gather(1, next_actions)
        target_quantile_values = rewards + gamma_with_terminal * target_quantile_values
        target_quantile_values = target_quantile_values.detach()

        # Reshape to n_tau_prime_samples x batch_size x 1 since this is
        # the manner in which the target_quantile_values are tiled.
        target_quantile_values = target_quantile_values.view(
            head_cfg.configs.n_tau_prime_samples, batch_size, 1)

        # Transpose dimensions so that the dimensionality is batch_size x
        # n_tau_prime_samples x 1 to prepare for computation of Bellman errors.
        target_quantile_values = torch.transpose(target_quantile_values, 0, 1)

        # Get quantile values: (n_tau_samples x batch_size) x action_dim.
        quantile_values, quantiles = model.forward_(
            states, head_cfg.configs.n_tau_samples)
        reshaped_actions = actions[:, None].repeat(
            head_cfg.configs.n_tau_samples, 1)
        chosen_action_quantile_values = quantile_values.gather(
            1, reshaped_actions.long())
        chosen_action_quantile_values = chosen_action_quantile_values.view(
            head_cfg.configs.n_tau_samples, batch_size, 1)

        # Transpose dimensions so that the dimensionality is batch_size x
        # n_tau_prime_samples x 1 to prepare for computation of Bellman errors.
        chosen_action_quantile_values = torch.transpose(
            chosen_action_quantile_values, 0, 1)

        # Shape of bellman_errors and huber_loss:
        # batch_size x num_tau_prime_samples x num_tau_samples x 1.
        bellman_errors = (target_quantile_values[:, :, None, :] -
                          chosen_action_quantile_values[:, None, :, :])

        # The huber loss (introduced in QR-DQN) is defined via two cases:
        # case_one: |bellman_errors| <= kappa
        # case_two: |bellman_errors| > kappa
        huber_loss_case_one = (
            (torch.abs(bellman_errors) <= head_cfg.configs.kappa).float() *
            0.5 * bellman_errors**2)
        huber_loss_case_two = (
            (torch.abs(bellman_errors) > head_cfg.configs.kappa).float() *
            head_cfg.configs.kappa *
            (torch.abs(bellman_errors) - 0.5 * head_cfg.configs.kappa))
        huber_loss = huber_loss_case_one + huber_loss_case_two

        # Reshape quantiles to batch_size x num_tau_samples x 1
        quantiles = quantiles.view(head_cfg.configs.n_tau_samples, batch_size,
                                   1)
        quantiles = torch.transpose(quantiles, 0, 1)

        # Tile by num_tau_prime_samples along a new dimension. Shape is now
        # batch_size x num_tau_prime_samples x num_tau_samples x 1.
        # These quantiles will be used for computation of the quantile huber loss
        # below (see section 2.3 of the paper).
        quantiles = quantiles[:, None, :, :].repeat(
            1, head_cfg.configs.n_tau_prime_samples, 1, 1)
        quantiles = quantiles.to(device)

        # Shape: batch_size x n_tau_prime_samples x n_tau_samples x 1.
        quantile_huber_loss = (
            torch.abs(quantiles - (bellman_errors < 0).float().detach()) *
            huber_loss / head_cfg.configs.kappa)

        # Sum over current quantile value (n_tau_samples) dimension,
        # average over target quantile value (n_tau_prime_samples) dimension.
        # Shape: batch_size x n_tau_prime_samples x 1.
        loss = torch.sum(quantile_huber_loss, dim=2)

        # Shape: batch_size x 1.
        iqn_loss_element_wise = torch.mean(loss, dim=1)

        # q values for regularization.
        q_values = model(states)

        return iqn_loss_element_wise, q_values
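
As a standalone sanity check of the broadcasting used in the quantile Huber loss above, the snippet below reproduces the same arithmetic on random tensors. The batch size, numbers of tau samples, and kappa are illustrative assumptions, not values taken from any config.

import torch

batch_size, n_tau, n_tau_prime, kappa = 4, 8, 8, 1.0

target_q = torch.randn(batch_size, n_tau_prime, 1)   # target quantile values
chosen_q = torch.randn(batch_size, n_tau, 1)         # online quantile values
taus = torch.rand(batch_size, 1, n_tau, 1)           # sampled quantile fractions

# batch_size x n_tau_prime x n_tau x 1, as in the Bellman-error block above
errors = target_q[:, :, None, :] - chosen_q[:, None, :, :]

huber = torch.where(
    errors.abs() <= kappa,
    0.5 * errors ** 2,
    kappa * (errors.abs() - 0.5 * kappa),
)
quantile_huber = (taus - (errors < 0).float()).abs() * huber / kappa
loss = quantile_huber.sum(dim=2).mean(dim=1)         # sum over n_tau, mean over n_tau_prime
print(loss.shape)  # torch.Size([4, 1])
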
Exemple #19
0
class GAILPPOLearner(PPOLearner):
    """PPO-based GAILLearner for GAIL Agent.

    Attributes:
        hyper_params (ConfigDict): hyper-parameters
        log_cfg (ConfigDict): configuration for saving log and checkpoint
        actor (nn.Module): actor model to select actions
        critic (nn.Module): critic model to predict state values
        discriminator (nn.Module): discriminator model to classify data
        actor_optim (Optimizer): optimizer for training actor
        critic_optim (Optimizer): optimizer for training critic
        discriminator_optim (Optimizer): optimizer for training discriminator

    """
    def __init__(
        self,
        hyper_params: ConfigDict,
        log_cfg: ConfigDict,
        backbone: ConfigDict,
        head: ConfigDict,
        optim_cfg: ConfigDict,
        env_name: str,
        state_size: tuple,
        output_size: int,
        is_test: bool,
        load_from: str,
    ):
        head.discriminator.configs.state_size = state_size
        head.discriminator.configs.action_size = output_size

        super().__init__(
            hyper_params,
            log_cfg,
            backbone,
            head,
            optim_cfg,
            env_name,
            state_size,
            output_size,
            is_test,
            load_from,
        )

        self.demo_memory = None

    def _init_network(self):
        """Initialize networks and optimizers."""
        # create actor
        if self.backbone_cfg.shared_actor_critic:
            shared_backbone = build_backbone(
                self.backbone_cfg.shared_actor_critic)
            self.actor = Brain(
                self.backbone_cfg.shared_actor_critic,
                self.head_cfg.actor,
                shared_backbone,
            )
            self.critic = Brain(
                self.backbone_cfg.shared_actor_critic,
                self.head_cfg.critic,
                shared_backbone,
            )
            self.actor = self.actor.to(self.device)
            self.critic = self.critic.to(self.device)
        else:
            self.actor = Brain(self.backbone_cfg.actor,
                               self.head_cfg.actor).to(self.device)
            self.critic = Brain(self.backbone_cfg.critic,
                                self.head_cfg.critic).to(self.device)
        self.discriminator = Discriminator(
            self.backbone_cfg.discriminator,
            self.head_cfg.discriminator,
            self.head_cfg.action_embedder,
        ).to(self.device)

        # create optimizer
        self.actor_optim = optim.Adam(
            self.actor.parameters(),
            lr=self.optim_cfg.lr_actor,
            weight_decay=self.optim_cfg.weight_decay,
        )

        self.critic_optim = optim.Adam(
            self.critic.parameters(),
            lr=self.optim_cfg.lr_critic,
            weight_decay=self.optim_cfg.weight_decay,
        )

        self.discriminator_optim = optim.Adam(
            self.discriminator.parameters(),
            lr=self.optim_cfg.lr_discriminator,
            weight_decay=self.optim_cfg.weight_decay,
        )

        # load model parameters
        if self.load_from is not None:
            self.load_params(self.load_from)

    def update_model(self, experience: TensorTuple,
                     epsilon: float) -> TensorTuple:
        """Update generator(actor), critic and discriminator networks."""
        states, actions, rewards, values, log_probs, next_state, masks = experience
        next_state = numpy2floattensor(next_state, self.device)
        with torch.no_grad():
            next_value = self.critic(next_state)

        returns = ppo_utils.compute_gae(
            next_value,
            rewards,
            masks,
            values,
            self.hyper_params.gamma,
            self.hyper_params.tau,
        )

        states = torch.cat(states)
        actions = torch.cat(actions)
        returns = torch.cat(returns).detach()
        values = torch.cat(values).detach()
        log_probs = torch.cat(log_probs).detach()
        advantages = (returns - values).detach()

        if self.hyper_params.standardize_advantage:
            advantages = (advantages - advantages.mean()) / (advantages.std() +
                                                             1e-7)

        actor_losses, critic_losses, total_losses, discriminator_losses = [], [], [], []

        for (
                state,
                action,
                old_value,
                old_log_prob,
                return_,
                adv,
                epoch,
        ) in ppo_utils.ppo_iter(
                self.hyper_params.epoch,
                self.hyper_params.batch_size,
                states,
                actions,
                values,
                log_probs,
                returns,
                advantages,
        ):

            # critic_loss
            value = self.critic(state)
            if self.hyper_params.use_clipped_value_loss:
                value_pred_clipped = old_value + torch.clamp(
                    (value - old_value), -epsilon, epsilon)
                value_loss_clipped = (return_ - value_pred_clipped).pow(2)
                value_loss = (return_ - value).pow(2)
                critic_loss = 0.5 * torch.max(value_loss,
                                              value_loss_clipped).mean()
            else:
                critic_loss = 0.5 * (return_ - value).pow(2).mean()
            critic_loss_ = self.hyper_params.w_value * critic_loss

            # train critic
            self.critic_optim.zero_grad()
            critic_loss_.backward()
            clip_grad_norm_(self.critic.parameters(),
                            self.hyper_params.gradient_clip_cr)
            self.critic_optim.step()

            # calculate ratios
            _, dist = self.actor(state)
            log_prob = dist.log_prob(action)
            ratio = (log_prob - old_log_prob).exp()

            # actor_loss
            surr_loss = ratio * adv
            clipped_surr_loss = torch.clamp(ratio, 1.0 - epsilon,
                                            1.0 + epsilon) * adv
            actor_loss = -torch.min(surr_loss, clipped_surr_loss).mean()

            # entropy
            entropy = dist.entropy().mean()
            actor_loss_ = actor_loss - self.hyper_params.w_entropy * entropy

            # train actor
            self.actor_optim.zero_grad()
            actor_loss_.backward()
            clip_grad_norm_(self.actor.parameters(),
                            self.hyper_params.gradient_clip_ac)
            self.actor_optim.step()

            # total_loss
            total_loss = critic_loss_ + actor_loss_

            # discriminator loss
            demo_state, demo_action = self.demo_memory.sample(len(state))
            exp_score = torch.sigmoid(
                self.discriminator.forward((state, action)))
            demo_score = torch.sigmoid(
                self.discriminator.forward((demo_state, demo_action)))
            discriminator_exp_acc = (exp_score > 0.5).float().mean().item()
            discriminator_demo_acc = (demo_score <= 0.5).float().mean().item()
            discriminator_loss = F.binary_cross_entropy(
                exp_score,
                torch.ones_like(exp_score)) + F.binary_cross_entropy(
                    demo_score, torch.zeros_like(demo_score))

            # train discriminator
            if (discriminator_exp_acc < self.optim_cfg.discriminator_acc_threshold
                    or (discriminator_demo_acc < self.optim_cfg.discriminator_acc_threshold
                        and epoch == 0)):
                self.discriminator_optim.zero_grad()
                discriminator_loss.backward()
                self.discriminator_optim.step()

            actor_losses.append(actor_loss.item())
            critic_losses.append(critic_loss.item())
            total_losses.append(total_loss.item())
            discriminator_losses.append(discriminator_loss.item())

        actor_loss = sum(actor_losses) / len(actor_losses)
        critic_loss = sum(critic_losses) / len(critic_losses)
        total_loss = sum(total_losses) / len(total_losses)
        discriminator_loss = sum(discriminator_losses) / len(
            discriminator_losses)

        return (
            (actor_loss, critic_loss, total_loss, discriminator_loss),
            (discriminator_exp_acc, discriminator_demo_acc),
        )

    def save_params(self, n_episode: int):
        """Save model and optimizer parameters."""
        params = {
            "actor_state_dict": self.actor.state_dict(),
            "critic_state_dict": self.critic.state_dict(),
            "discriminator_state_dict": self.discriminator.state_dict(),
            "actor_optim_state_dict": self.actor_optim.state_dict(),
            "critic_optim_state_dict": self.critic_optim.state_dict(),
            "discriminator_optim_state_dict": self.discriminator_optim.state_dict(),
        }
        PPOLearner._save_params(self, params, n_episode)

    def load_params(self, path: str):
        """Load model and optimizer parameters."""
        PPOLearner.load_params(self, path)

        params = torch.load(path)
        self.actor.load_state_dict(params["actor_state_dict"])
        self.critic.load_state_dict(params["critic_state_dict"])
        self.discriminator.load_state_dict(params["discriminator_state_dict"])
        self.actor_optim.load_state_dict(params["actor_optim_state_dict"])
        self.critic_optim.load_state_dict(params["critic_optim_state_dict"])
        self.discriminator_optim.load_state_dict(
            params["discriminator_optim_state_dict"])
        print("[INFO] loaded the model and optimizer from", path)

    def get_state_dict(self) -> Tuple[OrderedDict]:
        """Return state dicts, mainly for distributed worker."""
        return (
            self.actor.state_dict(),
            self.critic.state_dict(),
            self.discriminator.state_dict(),
        )

    def set_demo_memory(self, demo_memory):
        """Set the expert demonstration buffer sampled by the discriminator update."""
        self.demo_memory = demo_memory
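
The only interface `update_model` requires from `demo_memory` is a `sample(n)` call returning a `(states, actions)` pair. Below is a minimal sketch of such a buffer with uniform sampling; the class name and storage layout are illustrative assumptions, not the repository's actual demo buffer.

import torch


class SimpleDemoMemory:
    """Toy expert-demonstration buffer exposing the sample() call used above."""

    def __init__(self, demo_states: torch.Tensor, demo_actions: torch.Tensor):
        self.demo_states = demo_states
        self.demo_actions = demo_actions

    def sample(self, batch_size: int):
        """Uniformly sample a batch of expert (state, action) pairs."""
        idx = torch.randint(0, len(self.demo_states), (batch_size,))
        return self.demo_states[idx], self.demo_actions[idx]

It would be attached with `learner.set_demo_memory(SimpleDemoMemory(states, actions))` before the first call to `update_model`.
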
Exemple #20
0
    def __call__(
        self,
        model: Brain,
        target_model: Brain,
        experiences: Tuple[torch.Tensor, ...],
        gamma: float,
        head_cfg: ConfigDict,
        burn_in_step: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return R2D1 loss and Q-values."""
        # TODO: Combine with IQNLoss
        output_size = head_cfg.configs.output_size
        (
            burnin_states_tuple,
            states_tuple,
            burnin_prev_actions_tuple,
            agent_actions,
            prev_actions_tuple,
            burnin_prev_rewards_tuple,
            agent_rewards,
            prev_rewards_tuple,
            burnin_dones_tuple,
            agent_dones,
            init_rnn_state,
        ) = slice_r2d1_arguments(experiences, burn_in_step, output_size)

        batch_size = states_tuple[0].shape[0]
        sequence_size = states_tuple[0].shape[1]

        with torch.no_grad():
            _, target_rnn_state = target_model(
                burnin_states_tuple[1],
                init_rnn_state,
                burnin_prev_actions_tuple[1],
                burnin_prev_rewards_tuple[1],
            )
            _, init_rnn_state = model(
                burnin_states_tuple[0],
                init_rnn_state,
                burnin_prev_actions_tuple[0],
                burnin_prev_rewards_tuple[0],
            )

            init_rnn_state = torch.transpose(init_rnn_state, 0, 1)
            target_rnn_state = torch.transpose(target_rnn_state, 0, 1)

        burnin_invalid_mask = valid_from_done(burnin_dones_tuple[0].transpose(
            0, 1))
        burnin_target_invalid_mask = valid_from_done(
            burnin_dones_tuple[1].transpose(0, 1))
        init_rnn_state[burnin_invalid_mask] = 0
        target_rnn_state[burnin_target_invalid_mask] = 0

        # size of rewards: (n_tau_prime_samples x batch_size) x 1.
        agent_rewards = agent_rewards.repeat(
            head_cfg.configs.n_tau_prime_samples, 1, 1)

        # size of gamma_with_terminal: (n_tau_prime_samples x batch_size) x 1.
        masks = 1 - agent_dones
        gamma_with_terminal = masks * gamma
        gamma_with_terminal = gamma_with_terminal.repeat(
            head_cfg.configs.n_tau_prime_samples, 1, 1)
        # Get the indices of the maximum Q-value across the action dimension.
        # Shape of replay_next_qt_argmax: (n_tau_prime_samples x batch_size) x 1.
        next_q_values, _ = model(
            states_tuple[1],
            target_rnn_state,
            prev_actions_tuple[1],
            prev_rewards_tuple[1],
        )
        next_actions = next_q_values.argmax(dim=-1)  # double Q
        next_actions = next_actions[:, :, None]
        next_actions = next_actions.repeat(
            head_cfg.configs.n_tau_prime_samples, 1, 1)

        with torch.no_grad():
            # Shape of next_target_values: (n_tau_prime_samples x batch_size) x 1.
            target_quantile_values, _, _ = target_model.forward_(
                states_tuple[1],
                target_rnn_state,
                prev_actions_tuple[1],
                prev_rewards_tuple[1],
                head_cfg.configs.n_tau_prime_samples,
            )
            target_quantile_values = target_quantile_values.gather(
                -1, next_actions)
            target_quantile_values = (
                agent_rewards + gamma_with_terminal * target_quantile_values)
            target_quantile_values = target_quantile_values.detach()

            # Reshape to n_tau_prime_samples x batch_size x 1 since this is
            # the manner in which the target_quantile_values are tiled.
            target_quantile_values = target_quantile_values.view(
                head_cfg.configs.n_tau_prime_samples, batch_size,
                sequence_size, 1)

            # Transpose dimensions so that the dimensionality is batch_size x
            # n_tau_prime_samples x 1 to prepare for computation of Bellman errors.
            target_quantile_values = torch.transpose(target_quantile_values, 0,
                                                     1)

        # Get quantile values: (n_tau_samples x batch_size) x action_dim.
        quantile_values, quantiles, _ = model.forward_(
            states_tuple[0],
            init_rnn_state,
            prev_actions_tuple[0],
            prev_rewards_tuple[0],
            head_cfg.configs.n_tau_samples,
        )

        reshaped_actions = agent_actions.repeat(head_cfg.configs.n_tau_samples,
                                                1, 1)
        chosen_action_quantile_values = quantile_values.gather(
            -1, reshaped_actions.long())
        chosen_action_quantile_values = chosen_action_quantile_values.view(
            head_cfg.configs.n_tau_samples, batch_size, sequence_size, 1)

        # Transpose dimensions so that the dimensionality is batch_size x
        # n_tau_prime_samples x 1 to prepare for computation of Bellman errors.
        chosen_action_quantile_values = torch.transpose(
            chosen_action_quantile_values, 0, 1)

        # Shape of bellman_errors and huber_loss:
        # batch_size x num_tau_prime_samples x num_tau_samples x sequence_length x 1.
        bellman_errors = (target_quantile_values[:, :, None, :] -
                          chosen_action_quantile_values[:, None, :, :])

        # The huber loss (introduced in QR-DQN) is defined via two cases:
        # case_one: |bellman_errors| <= kappa
        # case_two: |bellman_errors| > kappa
        huber_loss_case_one = (
            (torch.abs(bellman_errors) <= head_cfg.configs.kappa).float() *
            0.5 * bellman_errors**2)
        huber_loss_case_two = (
            (torch.abs(bellman_errors) > head_cfg.configs.kappa).float() *
            head_cfg.configs.kappa *
            (torch.abs(bellman_errors) - 0.5 * head_cfg.configs.kappa))
        huber_loss = huber_loss_case_one + huber_loss_case_two

        # Reshape quantiles to batch_size x num_tau_samples x 1
        quantiles = quantiles.view(head_cfg.configs.n_tau_samples, batch_size,
                                   sequence_size, 1)
        quantiles = torch.transpose(quantiles, 0, 1)

        # Tile by num_tau_prime_samples along a new dimension. Shape is now
        # batch_size x num_tau_prime_samples x num_tau_samples x sequence_length x 1.
        # These quantiles will be used for computation of the quantile huber loss
        # below (see section 2.3 of the paper).
        quantiles = quantiles[:, None, :, :, :].repeat(
            1, head_cfg.configs.n_tau_prime_samples, 1, 1, 1)

        # Shape: batch_size x n_tau_prime_samples x n_tau_samples x sequence_length x 1.
        quantile_huber_loss = (
            torch.abs(quantiles - (bellman_errors < 0).float().detach()) *
            huber_loss / head_cfg.configs.kappa)

        # Sum over current quantile value (n_tau_samples) dimension,
        # average over target quantile value (n_tau_prime_samples) dimension.
        # Shape: batch_size x n_tau_prime_samples x sequence_length x 1.
        loss = torch.sum(quantile_huber_loss, dim=2)

        # Shape: batch_size x sequence_length x 1.
        iqn_loss_element_wise = torch.mean(loss, dim=1)

        # Shape: batch_size x 1.
        iqn_loss_element_wise = abs(torch.mean(iqn_loss_element_wise, dim=1))

        # q values for regularization.
        q_values, _ = model(
            states_tuple[0],
            init_rnn_state,
            prev_actions_tuple[0],
            prev_rewards_tuple[0],
        )

        return iqn_loss_element_wise, q_values
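
The repeat/view/transpose pattern used above (tile a tensor along the tau dimension, then recover a batch-major layout) is easy to verify in isolation. A toy check with made-up sizes:

import torch

n_tau, batch_size, sequence_size = 3, 2, 4
rewards = torch.arange(batch_size * sequence_size, dtype=torch.float32).view(
    batch_size, sequence_size, 1
)

tiled = rewards.repeat(n_tau, 1, 1)                        # (n_tau * batch) x seq x 1
untiled = tiled.view(n_tau, batch_size, sequence_size, 1)  # n_tau x batch x seq x 1
untiled = untiled.transpose(0, 1)                          # batch x n_tau x seq x 1

# every tau slice of a given batch element matches the original rewards
assert torch.equal(untiled[:, 0], rewards)
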
Exemple #21
0
class A2CLearner(Learner):
    """Learner for A2C Agent.

    Attributes:
        hyper_params (ConfigDict): hyper-parameters
        log_cfg (ConfigDict): configuration for saving log and checkpoint
        actor (nn.Module): actor model to select actions
        critic (nn.Module): critic model to predict state values
        actor_optim (Optimizer): optimizer for training actor
        critic_optim (Optimizer): optimizer for training critic

    """

    def __init__(
        self,
        hyper_params: ConfigDict,
        log_cfg: ConfigDict,
        backbone: ConfigDict,
        head: ConfigDict,
        optim_cfg: ConfigDict,
        env_name: str,
        state_size: tuple,
        output_size: int,
        is_test: bool,
        load_from: str,
    ):
        Learner.__init__(self, hyper_params, log_cfg, env_name, is_test)

        self.backbone_cfg = backbone
        self.head_cfg = head
        self.head_cfg.actor.configs.state_size = (
            self.head_cfg.critic.configs.state_size
        ) = state_size
        self.head_cfg.actor.configs.output_size = output_size
        self.optim_cfg = optim_cfg
        self.load_from = load_from

        self._init_network()

    def _init_network(self):
        """Initialize networks and optimizers."""
        self.actor = Brain(self.backbone_cfg.actor, self.head_cfg.actor).to(self.device)
        self.critic = Brain(self.backbone_cfg.critic, self.head_cfg.critic).to(
            self.device
        )

        # create optimizer
        self.actor_optim = optim.Adam(
            self.actor.parameters(),
            lr=self.optim_cfg.lr_actor,
            weight_decay=self.optim_cfg.weight_decay,
        )

        self.critic_optim = optim.Adam(
            self.critic.parameters(),
            lr=self.optim_cfg.lr_critic,
            weight_decay=self.optim_cfg.weight_decay,
        )

        if self.load_from is not None:
            self.load_params(self.load_from)

    def update_model(self, experience: TensorTuple) -> TensorTuple:
        """Update A2C actor and critic networks"""

        log_prob, pred_value, next_state, reward, done = experience
        next_state = numpy2floattensor(next_state, self.device)

        # Q_t   = r + gamma * V(s_{t+1})  if state != Terminal
        #       = r                       otherwise
        mask = 1 - done
        next_value = self.critic(next_state).detach()
        q_value = reward + self.hyper_params.gamma * next_value * mask
        q_value = q_value.to(self.device)

        # advantage = Q_t - V(s_t)
        advantage = q_value - pred_value

        # calculate loss at the current step
        policy_loss = -advantage.detach() * log_prob  # adv. is not backpropagated
        policy_loss += self.hyper_params.w_entropy * -log_prob  # entropy
        value_loss = F.smooth_l1_loss(pred_value, q_value.detach())

        # train
        gradient_clip_ac = self.hyper_params.gradient_clip_ac
        gradient_clip_cr = self.hyper_params.gradient_clip_cr

        self.actor_optim.zero_grad()
        policy_loss.backward()
        clip_grad_norm_(self.actor.parameters(), gradient_clip_ac)
        self.actor_optim.step()

        self.critic_optim.zero_grad()
        value_loss.backward()
        clip_grad_norm_(self.critic.parameters(), gradient_clip_cr)
        self.critic_optim.step()

        return policy_loss.item(), value_loss.item()

    def save_params(self, n_episode: int):
        """Save model and optimizer parameters."""
        params = {
            "actor_state_dict": self.actor.state_dict(),
            "critic_state_dict": self.critic.state_dict(),
            "actor_optim_state_dict": self.actor_optim.state_dict(),
            "critic_optim_state_dict": self.critic_optim.state_dict(),
        }

        Learner._save_params(self, params, n_episode)

    def load_params(self, path: str):
        """Load model and optimizer parameters."""
        Learner.load_params(self, path)

        params = torch.load(path)
        self.actor.load_state_dict(params["actor_state_dict"])
        self.critic.load_state_dict(params["critic_state_dict"])
        self.actor_optim.load_state_dict(params["actor_optim_state_dict"])
        self.critic_optim.load_state_dict(params["critic_optim_state_dict"])
        print("[INFO] Loaded the model and optimizer from", path)

    def get_state_dict(self) -> Tuple[OrderedDict]:
        """Return state dicts, mainly for distributed worker."""
        return (self.critic.state_dict(), self.actor.state_dict())

    def get_policy(self) -> nn.Module:
        """Return model (policy) used for action selection."""
        return self.actor
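
A quick numeric check of the bootstrapped target and advantage computed in `update_model` above, with made-up reward and value estimates:

import torch

gamma = 0.99
reward = torch.tensor([1.0])
done = torch.tensor([0.0])           # non-terminal transition
pred_value = torch.tensor([2.0])     # V(s_t)
next_value = torch.tensor([3.0])     # V(s_{t+1})

mask = 1 - done
q_value = reward + gamma * next_value * mask   # 1 + 0.99 * 3 = 3.97
advantage = q_value - pred_value               # 3.97 - 2 = 1.97
print(q_value.item(), advantage.item())
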
Exemple #22
0
class PPOLearner(Learner):
    """Learner for PPO Agent.

    Attributes:
        args (argparse.Namespace): arguments including hyperparameters and training settings
        hyper_params (ConfigDict): hyper-parameters
        log_cfg (ConfigDict): configuration for saving log and checkpoint
        actor (nn.Module): actor model to select actions
        critic (nn.Module): critic model to predict state values
        actor_optim (Optimizer): optimizer for training actor
        critic_optim (Optimizer): optimizer for training critic

    """
    def __init__(
        self,
        args: argparse.Namespace,
        env_info: ConfigDict,
        hyper_params: ConfigDict,
        log_cfg: ConfigDict,
        backbone: ConfigDict,
        head: ConfigDict,
        optim_cfg: ConfigDict,
    ):
        Learner.__init__(
            self,
            args,
            env_info,
            hyper_params,
            log_cfg,
        )

        self.backbone_cfg = backbone
        self.head_cfg = head
        self.head_cfg.actor.configs.state_size = (
            self.head_cfg.critic.configs.state_size
        ) = self.env_info.observation_space.shape
        self.head_cfg.actor.configs.output_size = self.env_info.action_space.shape[
            0]
        self.optim_cfg = optim_cfg
        self.is_discrete = self.hyper_params.is_discrete

        self._init_network()

    def _init_network(self):
        """Initialize networks and optimizers."""
        # create actor
        self.actor = Brain(self.backbone_cfg.actor,
                           self.head_cfg.actor).to(self.device)
        self.critic = Brain(self.backbone_cfg.critic,
                            self.head_cfg.critic).to(self.device)

        # create optimizer
        self.actor_optim = optim.Adam(
            self.actor.parameters(),
            lr=self.optim_cfg.lr_actor,
            weight_decay=self.optim_cfg.weight_decay,
        )

        self.critic_optim = optim.Adam(
            self.critic.parameters(),
            lr=self.optim_cfg.lr_critic,
            weight_decay=self.optim_cfg.weight_decay,
        )

        # load model parameters
        if self.args.load_from is not None:
            self.load_params(self.args.load_from)

    def update_model(self, experience: TensorTuple,
                     epsilon: float) -> TensorTuple:
        """Update PPO actor and critic networks"""
        states, actions, rewards, values, log_probs, next_state, masks = experience
        next_state = numpy2floattensor(next_state, self.device)
        next_value = self.critic(next_state)

        returns = ppo_utils.compute_gae(
            next_value,
            rewards,
            masks,
            values,
            self.hyper_params.gamma,
            self.hyper_params.tau,
        )

        states = torch.cat(states)
        actions = torch.cat(actions)
        returns = torch.cat(returns).detach()
        values = torch.cat(values).detach()
        log_probs = torch.cat(log_probs).detach()
        advantages = returns - values

        if self.is_discrete:
            actions = actions.unsqueeze(1)
            log_probs = log_probs.unsqueeze(1)

        if self.hyper_params.standardize_advantage:
            advantages = (advantages - advantages.mean()) / (advantages.std() +
                                                             1e-7)

        actor_losses, critic_losses, total_losses = [], [], []

        for state, action, old_value, old_log_prob, return_, adv in ppo_utils.ppo_iter(
                self.hyper_params.epoch,
                self.hyper_params.batch_size,
                states,
                actions,
                values,
                log_probs,
                returns,
                advantages,
        ):
            # calculate ratios
            _, dist = self.actor(state)
            log_prob = dist.log_prob(action)
            ratio = (log_prob - old_log_prob).exp()

            # actor_loss
            surr_loss = ratio * adv
            clipped_surr_loss = torch.clamp(ratio, 1.0 - epsilon,
                                            1.0 + epsilon) * adv
            actor_loss = -torch.min(surr_loss, clipped_surr_loss).mean()

            # critic_loss
            value = self.critic(state)
            if self.hyper_params.use_clipped_value_loss:
                value_pred_clipped = old_value + torch.clamp(
                    (value - old_value), -epsilon, epsilon)
                value_loss_clipped = (return_ - value_pred_clipped).pow(2)
                value_loss = (return_ - value).pow(2)
                critic_loss = 0.5 * torch.max(value_loss,
                                              value_loss_clipped).mean()
            else:
                critic_loss = 0.5 * (return_ - value).pow(2).mean()

            # entropy
            entropy = dist.entropy().mean()

            # total_loss
            w_value = self.hyper_params.w_value
            w_entropy = self.hyper_params.w_entropy

            critic_loss_ = w_value * critic_loss
            actor_loss_ = actor_loss - w_entropy * entropy
            total_loss = critic_loss_ + actor_loss_

            # train critic
            gradient_clip_ac = self.hyper_params.gradient_clip_ac
            gradient_clip_cr = self.hyper_params.gradient_clip_cr

            self.critic_optim.zero_grad()
            critic_loss_.backward(retain_graph=True)
            clip_grad_norm_(self.critic.parameters(), gradient_clip_cr)
            self.critic_optim.step()

            # train actor
            self.actor_optim.zero_grad()
            actor_loss_.backward()
            clip_grad_norm_(self.actor.parameters(), gradient_clip_ac)
            self.actor_optim.step()

            actor_losses.append(actor_loss.item())
            critic_losses.append(critic_loss.item())
            total_losses.append(total_loss.item())

        actor_loss = sum(actor_losses) / len(actor_losses)
        critic_loss = sum(critic_losses) / len(critic_losses)
        total_loss = sum(total_losses) / len(total_losses)

        return actor_loss, critic_loss, total_loss

    def save_params(self, n_episode: int):
        """Save model and optimizer parameters."""
        params = {
            "actor_state_dict": self.actor.state_dict(),
            "critic_state_dict": self.critic.state_dict(),
            "actor_optim_state_dict": self.actor_optim.state_dict(),
            "critic_optim_state_dict": self.critic_optim.state_dict(),
        }
        Learner._save_params(self, params, n_episode)

    def load_params(self, path: str):
        """Load model and optimizer parameters."""
        Learner.load_params(self, path)

        params = torch.load(path)
        self.actor.load_state_dict(params["actor_state_dict"])
        self.critic.load_state_dict(params["critic_state_dict"])
        self.actor_optim.load_state_dict(params["actor_optim_state_dict"])
        self.critic_optim.load_state_dict(params["critic_optim_state_dict"])
        print("[INFO] loaded the model and optimizer from", path)

    def get_state_dict(self) -> Tuple[OrderedDict]:
        """Return state dicts, mainly for distributed worker."""
        return (self.actor.state_dict(), self.critic.state_dict())

    def get_policy(self) -> nn.Module:
        """Return model (policy) used for action selection."""
        return self.actor
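
The clipping behaviour of the surrogate objective above can be seen on a two-sample toy example; the log-probabilities and advantages are made up:

import torch

epsilon = 0.2
old_log_prob = torch.tensor([-1.0, -1.0])
log_prob = torch.tensor([-0.5, -2.0])
adv = torch.tensor([1.0, 1.0])

ratio = (log_prob - old_log_prob).exp()        # ~[1.65, 0.37]
surr_loss = ratio * adv
clipped_surr_loss = torch.clamp(ratio, 1.0 - epsilon, 1.0 + epsilon) * adv
actor_loss = -torch.min(surr_loss, clipped_surr_loss).mean()
print(actor_loss.item())  # the first ratio (~1.65) is clipped to 1 + epsilon = 1.2
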
Exemple #23
0
class DQNLearner(Learner):
    """Learner for DQN Agent.

    Attributes:
        args (argparse.Namespace): arguments including hyperparameters and training settings
        hyper_params (ConfigDict): hyper-parameters
        log_cfg (ConfigDict): configuration for saving log and checkpoint
        dqn (nn.Module): dqn model to predict state Q values
        dqn_target (nn.Module): target dqn model to predict state Q values
        dqn_optim (Optimizer): optimizer for training dqn

    """
    def __init__(
        self,
        args: argparse.Namespace,
        env_info: ConfigDict,
        hyper_params: ConfigDict,
        log_cfg: ConfigDict,
        backbone: ConfigDict,
        head: ConfigDict,
        optim_cfg: ConfigDict,
        loss_type: ConfigDict,
    ):
        Learner.__init__(self, args, env_info, hyper_params, log_cfg)
        self.backbone_cfg = backbone
        self.head_cfg = head
        self.head_cfg.configs.state_size = self.env_info.observation_space.shape
        self.head_cfg.configs.output_size = self.env_info.action_space.n
        self.optim_cfg = optim_cfg
        self.use_n_step = self.hyper_params.n_step > 1
        self.loss_type = loss_type
        self._init_network()

    # pylint: disable=attribute-defined-outside-init
    def _init_network(self):
        """Initialize networks and optimizers."""
        self.dqn = Brain(self.backbone_cfg, self.head_cfg).to(self.device)
        self.dqn_target = Brain(self.backbone_cfg,
                                self.head_cfg).to(self.device)
        self.loss_fn = build_loss(self.loss_type)

        self.dqn_target.load_state_dict(self.dqn.state_dict())

        # create optimizer
        self.dqn_optim = optim.Adam(
            self.dqn.parameters(),
            lr=self.optim_cfg.lr_dqn,
            weight_decay=self.optim_cfg.weight_decay,
            eps=self.optim_cfg.adam_eps,
        )

        # load the optimizer and model parameters
        if self.args.load_from is not None:
            self.load_params(self.args.load_from)

    def update_model(
        self, experience: Union[TensorTuple, Tuple[TensorTuple]]
    ) -> Tuple[torch.Tensor, torch.Tensor, list, np.ndarray]:  # type: ignore
        """Update dqn and dqn target."""

        if self.use_n_step:
            experience_1, experience_n = experience
        else:
            experience_1 = experience

        weights, indices = experience_1[-3:-1]

        gamma = self.hyper_params.gamma

        dq_loss_element_wise, q_values = self.loss_fn(self.dqn,
                                                      self.dqn_target,
                                                      experience_1, gamma,
                                                      self.head_cfg)

        dq_loss = torch.mean(dq_loss_element_wise * weights)

        # n step loss
        if self.use_n_step:
            gamma = self.hyper_params.gamma**self.hyper_params.n_step

            dq_loss_n_element_wise, q_values_n = self.loss_fn(
                self.dqn, self.dqn_target, experience_n, gamma, self.head_cfg)

            # to update loss and priorities
            q_values = 0.5 * (q_values + q_values_n)
            dq_loss_element_wise += dq_loss_n_element_wise * self.hyper_params.w_n_step
            dq_loss = torch.mean(dq_loss_element_wise * weights)

        # q_value regularization
        q_regular = torch.norm(q_values, 2).mean() * self.hyper_params.w_q_reg

        # total loss
        loss = dq_loss + q_regular

        self.dqn_optim.zero_grad()
        loss.backward()
        clip_grad_norm_(self.dqn.parameters(), self.hyper_params.gradient_clip)
        self.dqn_optim.step()

        # update target networks
        common_utils.soft_update(self.dqn, self.dqn_target,
                                 self.hyper_params.tau)

        # update priorities in PER
        loss_for_prior = dq_loss_element_wise.detach().cpu().numpy()
        new_priorities = loss_for_prior + self.hyper_params.per_eps

        if self.head_cfg.configs.use_noisy_net:
            self.dqn.head.reset_noise()
            self.dqn_target.head.reset_noise()

        return (
            loss.item(),
            q_values.mean().item(),
            indices,
            new_priorities,
        )

    def save_params(self, n_episode: int):
        """Save model and optimizer parameters."""
        params = {
            "dqn_state_dict": self.dqn.state_dict(),
            "dqn_target_state_dict": self.dqn_target.state_dict(),
            "dqn_optim_state_dict": self.dqn_optim.state_dict(),
        }

        Learner._save_params(self, params, n_episode)

    # pylint: disable=attribute-defined-outside-init
    def load_params(self, path: str):
        """Load model and optimizer parameters."""
        Learner.load_params(self, path)

        params = torch.load(path)
        self.dqn.load_state_dict(params["dqn_state_dict"])
        self.dqn_target.load_state_dict(params["dqn_target_state_dict"])
        self.dqn_optim.load_state_dict(params["dqn_optim_state_dict"])
        print("[INFO] loaded the model and optimizer from", path)

    def get_state_dict(self) -> OrderedDict:
        """Return state dicts, mainly for distributed worker."""
        dqn = deepcopy(self.dqn)
        return dqn.cpu().state_dict()

    def get_policy(self) -> nn.Module:
        """Return model (policy) used for action selection, used only in grad cam."""
        return self.dqn
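
The `new_priorities` returned by `update_model` above are just the element-wise loss shifted by `per_eps`; a toy check with made-up losses (how they are written back into the prioritized buffer is outside this example):

import torch

per_eps = 1e-6
dq_loss_element_wise = torch.tensor([[0.5], [0.0], [2.3]])

loss_for_prior = dq_loss_element_wise.detach().cpu().numpy()
new_priorities = loss_for_prior + per_eps
print(new_priorities.squeeze())  # strictly positive, so every transition keeps a nonzero sampling chance
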
Exemple #24
0
    def _init_networks(self, state_dict: OrderedDict):
        """Initialize DQN policy with learner state dict."""
        self.dqn = Brain(self.backbone_cfg, self.head_cfg).to(self.device)
        self.dqn.load_state_dict(state_dict)
        self.dqn.eval()
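
A minimal sketch of how a worker could use the eval-mode policy built above for greedy action selection; the state preprocessing and the assumption that `dqn(state)` returns per-action Q-values are illustrative, not taken from the example:

import torch
import torch.nn as nn


def select_action(dqn: nn.Module, state: torch.Tensor) -> int:
    """Pick the argmax-Q action without tracking gradients."""
    with torch.no_grad():
        q_values = dqn(state.unsqueeze(0))  # add a batch dimension
    return int(q_values.argmax(dim=-1).item())
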
Exemple #25
0
    def __call__(
        self,
        model: Brain,
        target_model: Brain,
        experiences: Tuple[torch.Tensor, ...],
        gamma: float,
        head_cfg: ConfigDict,
        burn_in_step: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return element-wise C51 loss and Q-values."""
        # TODO: Combine with IQNLoss
        output_size = head_cfg.configs.output_size
        (
            burnin_states_tuple,
            states_tuple,
            burnin_prev_actions_tuple,
            agent_actions,
            prev_actions_tuple,
            burnin_prev_rewards_tuple,
            agent_rewards,
            prev_rewards_tuple,
            burnin_dones_tuple,
            agent_dones,
            init_rnn_state,
        ) = slice_r2d1_arguments(experiences, burn_in_step, output_size)

        batch_size = states_tuple[0].shape[0]
        sequence_size = states_tuple[0].shape[1]

        with torch.no_grad():
            _, target_rnn_state = target_model(
                burnin_states_tuple[1],
                init_rnn_state,
                burnin_prev_actions_tuple[1],
                burnin_prev_rewards_tuple[1],
            )
            _, init_rnn_state = model(
                burnin_states_tuple[0],
                init_rnn_state,
                burnin_prev_actions_tuple[0],
                burnin_prev_rewards_tuple[0],
            )

            init_rnn_state = torch.transpose(init_rnn_state, 0, 1)
            target_rnn_state = torch.transpose(target_rnn_state, 0, 1)

        burnin_invalid_mask = valid_from_done(burnin_dones_tuple[0].transpose(
            0, 1))
        burnin_target_invalid_mask = valid_from_done(
            burnin_dones_tuple[1].transpose(0, 1))
        init_rnn_state[burnin_invalid_mask] = 0
        target_rnn_state[burnin_target_invalid_mask] = 0

        support = torch.linspace(head_cfg.configs.v_min,
                                 head_cfg.configs.v_max,
                                 head_cfg.configs.atom_size).to(device)
        delta_z = float(head_cfg.configs.v_max - head_cfg.configs.v_min) / (
            head_cfg.configs.atom_size - 1)

        with torch.no_grad():
            # According to noisynet paper,
            # it resamples noisynet parameters on online network when using double q
            # but we don't because there is no remarkable difference in performance.
            (_, next_q_values), _ = model.forward_(
                states_tuple[1],
                target_rnn_state,
                prev_actions_tuple[1],
                prev_rewards_tuple[1],
            )
            next_actions = next_q_values.argmax(-1)

            (next_dist, _), _ = target_model.forward_(
                states_tuple[1],
                target_rnn_state,
                prev_actions_tuple[1],
                prev_rewards_tuple[1],
            )
            next_dist = next_dist[range(batch_size * sequence_size),
                                  next_actions]

            t_z = agent_rewards + (1 - agent_dones) * gamma * support
            t_z = t_z.clamp(min=head_cfg.configs.v_min,
                            max=head_cfg.configs.v_max)
            b = (t_z - head_cfg.configs.v_min) / delta_z
            b = b.view(batch_size * sequence_size, -1)
            l = b.floor().long()  # noqa: E741
            u = b.ceil().long()

            offset = (torch.linspace(
                0,
                (batch_size * sequence_size - 1) * head_cfg.configs.atom_size,
                batch_size * sequence_size,
            ).long().unsqueeze(1))
            offset = offset.expand(batch_size * sequence_size,
                                   head_cfg.configs.atom_size).to(device)
            proj_dist = torch.zeros(next_dist.size(), device=device)
            proj_dist.view(-1).index_add_(0, (l + offset).view(-1),
                                          (next_dist *
                                           (u.float() - b)).view(-1))
            proj_dist.view(-1).index_add_(0, (u + offset).view(-1),
                                          (next_dist *
                                           (b - l.float())).view(-1))

        (dist, q_values), _ = model.forward_(
            states_tuple[0],
            init_rnn_state,
            prev_actions_tuple[0],
            prev_rewards_tuple[0],
        )
        chosen_actions = agent_actions.contiguous().view(
            batch_size * sequence_size).long()
        log_p = dist[range(batch_size * sequence_size), chosen_actions]
        log_p = torch.log(log_p.clamp(min=1e-5))
        log_p = log_p.view(batch_size, sequence_size, -1)
        proj_dist = proj_dist.view(batch_size, sequence_size, -1)
        dq_loss_element_wise = -(proj_dist * log_p).sum(-1).mean(1)

        return dq_loss_element_wise, q_values
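
To make the projection step above concrete, here is a single-sample toy version of the same floor/ceil mass splitting; the support bounds and atom count are illustrative:

import torch

v_min, v_max, atom_size = 0.0, 10.0, 11
delta_z = (v_max - v_min) / (atom_size - 1)    # 1.0

t_z = torch.tensor([3.4]).clamp(v_min, v_max)
b = (t_z - v_min) / delta_z                    # 3.4
l, u = b.floor().long(), b.ceil().long()       # atoms 3 and 4

proj = torch.zeros(atom_size)
prob = torch.tensor([1.0])                     # all next-state mass on this sample
proj.index_add_(0, l, prob * (u.float() - b))  # atom 3 receives ~0.6
proj.index_add_(0, u, prob * (b - l.float()))  # atom 4 receives ~0.4
print(proj)
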