Example #1
class DistributedLogger(ABC):
    """Base class for loggers use in distributed training.

    Attributes:
        log_cfg (ConfigDict): configuration for saving logs and checkpoints
        comm_cfg (ConfigDict): configs for inter-process communication
        backbone (ConfigDict): backbone configs for building network
        head (ConfigDict): head configs for building network
        brain (Brain): logger brain for evaluation
        update_step (int): tracker for learner update step
        device (torch.device): device, cpu by default
        log_info_queue (deque): queue for storing log info received from learner
        env (gym.Env): gym environment for running tests

    """
    def __init__(
        self,
        log_cfg: ConfigDict,
        comm_cfg: ConfigDict,
        backbone: ConfigDict,
        head: ConfigDict,
        env_name: str,
        is_atari: bool,
        state_size: int,
        output_size: int,
        max_update_step: int,
        episode_num: int,
        max_episode_steps: int,
        interim_test_num: int,
        is_log: bool,
        is_render: bool,
    ):
        self.log_cfg = log_cfg
        self.comm_cfg = comm_cfg
        self.device = torch.device("cpu")  # Logger only runs on cpu
        head.configs.state_size = state_size
        head.configs.output_size = output_size
        self.brain = Brain(backbone, head).to(self.device)

        self.env_name = env_name
        self.is_atari = is_atari
        self.max_update_step = max_update_step
        self.episode_num = episode_num
        self.max_episode_steps = max_episode_steps
        self.interim_test_num = interim_test_num
        self.is_log = is_log
        self.is_render = is_render

        self.update_step = 0
        self.log_info_queue = deque(maxlen=100)
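        # maxlen=100 bounds the backlog: appending to a full queue silently
        # drops the oldest payload, so the logger never falls far behind.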

        self._init_env()

    # pylint: disable=attribute-defined-outside-init
    def _init_env(self):
        """Initialize gym environment."""
        if self.is_atari:
            self.env = atari_env_generator(self.env_name,
                                           self.max_episode_steps)
        else:
            self.env = gym.make(self.env_name)
            self.env, self.max_episode_steps = env_utils.set_env(
                self.env, self.max_episode_steps)

    @abstractmethod
    def load_params(self, path: str):
        if not os.path.exists(path):
            raise FileNotFoundError(
                f"[ERROR] The input path does not exist: {path}")

    # pylint: disable=attribute-defined-outside-init
    def init_communication(self):
        """Initialize inter-process communication sockets."""
        ctx = zmq.Context()
        self.pull_socket = ctx.socket(zmq.PULL)
        self.pull_socket.bind(
            f"tcp://127.0.0.1:{self.comm_cfg.learner_logger_port}")

    @abstractmethod
    def select_action(self, state: np.ndarray):
        pass

    @abstractmethod
    def write_log(self, log_value: dict):
        pass

    # pylint: disable=no-self-use
    @staticmethod
    def _preprocess_state(state: np.ndarray,
                          device: torch.device) -> torch.Tensor:
        state = numpy2floattensor(state, device)
        return state

    def set_wandb(self):
        """Set configuration for wandb logging."""
        wandb.init(
            project=self.env_name,
            name=f"{self.log_cfg.agent}/{self.log_cfg.curr_time}",
        )
        additional_log = dict(
            episode_num=self.episode_num,
            max_episode_steps=self.max_episode_steps,
        )
        wandb.config.update(additional_log)
        shutil.copy(self.log_cfg.cfg_path,
                    os.path.join(wandb.run.dir, "config.yaml"))

    def recv_log_info(self):
        """Receive info from learner."""
        received = False
        try:
            log_info_id = self.pull_socket.recv(zmq.DONTWAIT)
            received = True
        except zmq.Again:
            pass

        if received:
            self.log_info_queue.append(log_info_id)

    def run(self):
        """Run main logging loop; continuously receive data and log."""
        if self.is_log:
            self.set_wandb()

        while self.update_step < self.max_update_step:
            self.recv_log_info()
            if self.log_info_queue:  # if non-empty
                log_info_id = self.log_info_queue.pop()  # newest payload first
                log_info = pa.deserialize(log_info_id)
                state_dict = log_info["state_dict"]
                log_value = log_info["log_value"]
                self.update_step = log_value["update_step"]

                self.synchronize(state_dict)
                avg_score = self.test(self.update_step)
                log_value["avg_score"] = avg_score
                self.write_log(log_value)

    def write_worker_log(self, worker_logs: List[dict],
                         worker_update_interval: int):
        """Log the mean scores of each episode per update step to wandb."""
        # NOTE: Worker scores are logged as a single plotly figure since wandb
        #       doesn't support logging multiple lines to one native plot
        self.set_wandb()
        # Plot individual workers
        fig = go.Figure()
        for worker_id, worker_log in enumerate(worker_logs):
            fig.add_trace(
                go.Scatter(
                    x=list(worker_log.keys()),
                    y=smoothen_graph(list(worker_log.values())),
                    mode="lines",
                    name=f"Worker {worker_id}",
                    line=dict(width=2),
                ))

        # Plot mean scores
        logged_update_steps = list(
            range(0, self.max_update_step + 1, worker_update_interval))

        mean_scores = []
        try:
            for step in logged_update_steps:
                scores_for_step = []
                for worker_log in worker_logs:
                    if step in worker_log:
                        scores_for_step.append(worker_log[step])
                mean_scores.append(np.mean(scores_for_step))
        except Exception as e:
            print(f"[Error] {e}")

        fig.add_trace(
            go.Scatter(
                x=logged_update_steps,
                y=mean_scores,
                mode="lines+markers",
                name="Mean scores",
                line=dict(width=5),
            ))

        # Write to wandb
        wandb.log({"Worker scores": fig})

    def test(self, update_step: int, interim_test: bool = True):
        """Test the agent."""
        avg_score = self._test(update_step, interim_test)

        # termination
        self.env.close()
        return avg_score

    def _test(self, update_step: int, interim_test: bool) -> float:
        """Common test routine."""
        if interim_test:
            test_num = self.interim_test_num
        else:
            test_num = self.episode_num

        self.brain.eval()
        scores = []
        for i_episode in range(test_num):
            state = self.env.reset()
            done = False
            score = 0
            step = 0

            while not done:
                if self.is_render:
                    self.env.render()

                action = self.select_action(state)
                next_state, reward, done, _ = self.env.step(action)

                state = next_state
                score += reward
                step += 1

            scores.append(score)

            if interim_test:
                print(
                    "[INFO] update step: %d\ttest %d\tstep: %d\ttotal score: %d"
                    % (update_step, i_episode, step, score))
            else:
                print("[INFO] test %d\tstep: %d\ttotal score: %d" %
                      (i_episode, step, score))

        return np.mean(scores)

    def synchronize(self, state_dict: Dict[str, np.ndarray]):
        """Copy parameters from numpy arrays."""
        for param_name, logger_param in self.brain.named_parameters():
            if param_name in state_dict:
                new_param = numpy2floattensor(state_dict[param_name],
                                              self.device)
                logger_param.data.copy_(new_param)
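
The learner side of this exchange is not shown above. As a minimal sketch, assuming the same comm_cfg and pyarrow's legacy serialize API (the counterpart of the pa.deserialize call in run()), a matching sender could look like this; comm_cfg, numpy_state_dict, and update_step are placeholder names, not part of the class above.

import pyarrow as pa
import zmq

# Hypothetical learner-side sender: pairs a PUSH socket with the logger's
# PULL socket and sends payloads in the shape DistributedLogger.run() expects.
ctx = zmq.Context()
push_socket = ctx.socket(zmq.PUSH)
push_socket.connect(f"tcp://127.0.0.1:{comm_cfg.learner_logger_port}")

log_info = {
    "state_dict": numpy_state_dict,  # Dict[str, np.ndarray] of brain weights
    "log_value": {"update_step": update_step},  # extra metrics also go here
}
push_socket.send(pa.serialize(log_info).to_buffer())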
Example #2
class DQNWorker(DistributedWorker):
    """DQN worker for distributed training.

    Attributes:
        backbone (ConfigDict): backbone configs for building network
        head (ConfigDict): head configs for building network
        state_dict (OrderedDict): initial network state dict received from the learner
        device (str): literal to indicate cpu/cuda use

    """
    def __init__(
        self,
        rank: int,
        args: argparse.Namespace,
        env_info: ConfigDict,
        hyper_params: ConfigDict,
        backbone: ConfigDict,
        head: ConfigDict,
        state_dict: OrderedDict,
        device: str,
        loss_type: ConfigDict,
    ):
        DistributedWorker.__init__(self, rank, args, env_info, hyper_params,
                                   device)
        self.loss_fn = build_loss(loss_type)
        self.backbone_cfg = backbone
        self.head_cfg = head
        self.head_cfg.configs.state_size = self.env_info.observation_space.shape
        self.head_cfg.configs.output_size = self.env_info.action_space.n

        self.use_n_step = self.hyper_params.n_step > 1

        self.max_epsilon = self.hyper_params.max_epsilon
        self.min_epsilon = self.hyper_params.min_epsilon
        self.epsilon = self.hyper_params.max_epsilon

        self._init_networks(state_dict)

    # pylint: disable=attribute-defined-outside-init
    def _init_networks(self, state_dict: OrderedDict):
        """Initialize DQN policy with learner state dict."""
        self.dqn = Brain(self.backbone_cfg, self.head_cfg).to(self.device)
        self.dqn.load_state_dict(state_dict)
        self.dqn.eval()

    def load_params(self, path: str):
        """Load model and optimizer parameters."""
        DistributedWorker.load_params(self, path)

        params = torch.load(path)
        self.dqn.load_state_dict(params["dqn_state_dict"])
        print("[INFO] loaded the model and optimizer from", path)

    def select_action(self, state: np.ndarray) -> np.ndarray:
        """Select an action from the input space."""
        # epsilon greedy policy
        # pylint: disable=comparison-with-callable
        if self.epsilon > np.random.random():
            selected_action = np.array(self.env.action_space.sample())
        else:
            with torch.no_grad():
                state = self._preprocess_state(state, self.device)
                selected_action = self.dqn(state).argmax()
            selected_action = selected_action.cpu().numpy()

        # Decay epsilon
        self.epsilon = max(
            self.epsilon - (self.max_epsilon - self.min_epsilon) *
            self.hyper_params.epsilon_decay,
            self.min_epsilon,
        )

        return selected_action

    def step(self,
             action: np.ndarray) -> Tuple[np.ndarray, np.float64, bool, dict]:
        """Take an action and return the response of the env."""
        next_state, reward, done, info = self.env.step(action)
        return next_state, reward, done, info

    def compute_priorities(self, memory: Dict[str, np.ndarray]) -> np.ndarray:
        """Compute initial priority values of experiences in local memory."""
        states = numpy2floattensor(memory["states"], self.device)
        actions = numpy2floattensor(memory["actions"], self.device).long()
        rewards = numpy2floattensor(memory["rewards"].reshape(-1, 1),
                                    self.device)
        next_states = numpy2floattensor(memory["next_states"], self.device)
        dones = numpy2floattensor(memory["dones"].reshape(-1, 1), self.device)
        memory_tensors = (states, actions, rewards, next_states, dones)

        with torch.no_grad():
            # The worker keeps no separate target network, so the online
            # network is passed for both network arguments of the loss.
            dq_loss_element_wise, _ = self.loss_fn(
                self.dqn,
                self.dqn,
                memory_tensors,
                self.hyper_params.gamma,
                self.head_cfg,
            )
        loss_for_prior = dq_loss_element_wise.detach().cpu().numpy()
        new_priorities = loss_for_prior + self.hyper_params.per_eps
        return new_priorities

    def synchronize(self, new_state_dict: Dict[str, np.ndarray]):
        """Synchronize worker dqn with learner dqn."""
        self._synchronize(self.dqn, new_state_dict)
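
As a quick illustration of the epsilon schedule in select_action: each call shrinks epsilon by a fixed step of (max_epsilon - min_epsilon) * epsilon_decay, clipped at min_epsilon. A small sketch with hypothetical hyper-parameter values:

# Hypothetical values; the real ones come from hyper_params.
max_epsilon, min_epsilon, epsilon_decay = 1.0, 0.01, 1e-4
step_size = (max_epsilon - min_epsilon) * epsilon_decay  # 0.000099 per action

epsilon = max_epsilon
for _ in range(10_000):
    epsilon = max(epsilon - step_size, min_epsilon)
# epsilon reaches min_epsilon after about (1.0 - 0.01) / 0.000099 = 10,000
# actions, so exploration anneals over roughly the first 10k environment steps.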