Example #1
    def _initialize(self):
        """Set env specific configs and build learner."""
        self.experiment_info.env.state_dim = self.env.observation_space.shape[0]
        if self.experiment_info.env.is_discrete:
            self.experiment_info.env.action_dim = self.env.action_space.n
        else:
            self.experiment_info.env.action_dim = self.env.action_space.shape[0]
            self.experiment_info.env.action_range = [
                self.env.action_space.low.tolist(),
                self.env.action_space.high.tolist(),
            ]

        self.learner = build_learner(self.experiment_info, self.hyper_params,
                                     self.model_cfg)

        self.action_selector = build_action_selector(self.experiment_info,
                                                     self.use_cuda)

        # Build logger
        if self.experiment_info.log_wandb:
            experiment_cfg = OmegaConf.create(
                dict(
                    experiment_info=self.experiment_info,
                    hyper_params=self.hyper_params,
                    model=self.learner.model_cfg,
                ))
            self.logger = Logger(experiment_cfg)
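
For reference, the discrete-vs-continuous branch above maps directly onto standard gym space attributes. The snippet below is a generic, standalone sketch rather than code from this library; the environment id is an arbitrary choice and may need adjusting to your gym version.

# Generic sketch (not library code) of how the env-specific dims are derived.
import gym

env = gym.make("Pendulum-v1")  # arbitrary continuous-control env
state_dim = env.observation_space.shape[0]

if isinstance(env.action_space, gym.spaces.Discrete):
    action_dim = env.action_space.n
else:
    action_dim = env.action_space.shape[0]
    action_range = [env.action_space.low.tolist(),
                    env.action_space.high.tolist()]

print(state_dim, action_dim)
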
Example #2
    def _initialize(self):
        """Initialize agent components."""
        # Define env specific model params
        self.experiment_info.env.state_dim = self.env.observation_space.shape[0]
        self.experiment_info.env.action_dim = self.env.action_space.shape[0]
        self.experiment_info.env.action_range = [
            self.env.action_space.low.tolist(),
            self.env.action_space.high.tolist(),
        ]

        # Build learner
        self.learner = build_learner(self.experiment_info, self.hyper_params,
                                     self.model_cfg)

        # Build replay buffer, wrap with PER buffer if using it
        self.replay_buffer = ReplayBuffer(self.hyper_params)
        if self.hyper_params.use_per:
            self.replay_buffer = PrioritizedReplayBuffer(
                self.replay_buffer, self.hyper_params)

        # Build action selector
        self.action_selector = build_action_selector(self.experiment_info)
        self.action_selector = OUNoise(self.action_selector,
                                       self.env.action_space)

        # Build logger
        if self.experiment_info.log_wandb:
            experiment_cfg = OmegaConf.create(
                dict(
                    experiment_info=self.experiment_info,
                    hyper_params=self.hyper_params,
                    model=self.learner.model_cfg,
                ))
            self.logger = Logger(experiment_cfg)
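
The OUNoise wrapper applied to the action selector above supplies temporally correlated exploration noise, the usual choice for DDPG-style continuous control. The following is a minimal, generic Ornstein-Uhlenbeck sketch to illustrate what such a wrapper adds to each action; it is not the library's OUNoise class, and the parameter defaults (theta, sigma) are conventional values, not taken from the source.

import numpy as np

class OrnsteinUhlenbeckNoise:
    """Generic OU process for action exploration (illustrative only)."""

    def __init__(self, action_dim, mu=0.0, theta=0.15, sigma=0.2):
        self.mu = mu * np.ones(action_dim)
        self.theta = theta
        self.sigma = sigma
        self.state = self.mu.copy()

    def reset(self):
        self.state = self.mu.copy()

    def sample(self):
        # dx = theta * (mu - x) + sigma * N(0, 1): noise drifts back toward mu
        dx = self.theta * (self.mu - self.state) + self.sigma * np.random.randn(*self.state.shape)
        self.state = self.state + dx
        return self.state

# Usage: action = np.clip(policy_action + noise.sample(), action_low, action_high)
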
Example #3
    def _initialize(self):
        """Initialize agent components"""
        # set env specific model params
        self.model_cfg.params.model_cfg.state_dim = self.env.observation_space.shape
        self.model_cfg.params.model_cfg.action_dim = self.env.action_space.n

        # Build learner
        self.learner = build_learner(
            self.experiment_info, self.hyper_params, self.model_cfg
        )

        # Build replay buffer, wrap with PER buffer if using it
        self.replay_buffer = ReplayBuffer(self.hyper_params)
        if self.hyper_params.use_per:
            self.replay_buffer = PrioritizedReplayBuffer(
                self.replay_buffer, self.hyper_params
            )

        # Build action selector, wrap with e-greedy exploration
        self.action_selector = build_action_selector(self.experiment_info)
        self.action_selector = EpsGreedy(
            self.action_selector, self.env.action_space, self.hyper_params
        )

        if self.experiment_info.log_wandb:
            experiment_cfg = OmegaConf.create(
                dict(
                    experiment_info=self.experiment_info,
                    hyper_params=self.hyper_params,
                    model=self.learner.model_cfg,
                )
            )
            self.logger = Logger(experiment_cfg)
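
The EpsGreedy wrapper above combines a greedy action selector with random exploration and, as Example #8 shows, exposes an eps attribute and a decay_epsilon() method. The class below is a minimal sketch of that pattern under assumed hyper-parameter names (eps, eps_final, eps_decay); the real wrapper reads these from hyper_params and may differ.

import random

class EpsGreedyWrapper:
    """Illustrative epsilon-greedy wrapper; field names are assumptions."""

    def __init__(self, action_selector, action_space,
                 eps=1.0, eps_final=0.01, eps_decay=0.995):
        self.action_selector = action_selector
        self.action_space = action_space
        self.eps = eps
        self.eps_final = eps_final
        self.eps_decay = eps_decay

    def __call__(self, network, state):
        if random.random() < self.eps:
            return self.action_space.sample()             # explore
        return self.action_selector(network, state)       # exploit (greedy w.r.t. network)

    def decay_epsilon(self):
        self.eps = max(self.eps_final, self.eps * self.eps_decay)
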
Example #4
class SACAgent(Agent):
    """Soft Actor Critic agent

    Attributes:
        env (gym.ENV): Gym environment for RL agent
        learner (LEARNER): Carries and updates agent value/policy models
        replay_buffer (ReplayBuffer): Replay buffer for experience replay (PER as wrapper)
        action_selector (ActionSelector): Callable for action selection (RandomActionsStarts as wrapper)
        use_n_step (bool): Indication of using n-step updates
        transition_queue (Deque): deque for tracking and preprocessing n-step transitions

    """
    def __init__(
        self,
        experiment_info: DictConfig,
        hyper_params: DictConfig,
        model_cfg: DictConfig,
    ):
        Agent.__init__(self, experiment_info, hyper_params, model_cfg)
        self.use_n_step = self.hyper_params.n_step > 1
        self.transition_queue = deque(maxlen=self.hyper_params.n_step)
        self.update_step = 0

        self._initialize()

    def _initialize(self):
        """Initialize agent components"""
        # Build env and env specific model params
        self.experiment_info.env.state_dim = self.env.observation_space.shape[0]
        self.experiment_info.env.action_dim = self.env.action_space.shape[0]
        self.experiment_info.env.action_range = [
            self.env.action_space.low.tolist(),
            self.env.action_space.high.tolist(),
        ]

        # Build learner
        self.learner = build_learner(self.experiment_info, self.hyper_params,
                                     self.model_cfg)

        # Build replay buffer, wrap with PER buffer if using it
        self.replay_buffer = ReplayBuffer(self.hyper_params)
        if self.hyper_params.use_per:
            self.replay_buffer = PrioritizedReplayBuffer(
                self.replay_buffer, self.hyper_params)

        # Build action selector
        self.action_selector = build_action_selector(self.experiment_info,
                                                     self.use_cuda)
        self.action_selector = RandomActionsStarts(
            self.action_selector,
            max_exploratory_steps=self.hyper_params.max_exploratory_steps,
        )

        # Build logger
        if self.experiment_info.log_wandb:
            experiment_cfg = OmegaConf.create(
                dict(
                    experiment_info=self.experiment_info,
                    hyper_params=self.hyper_params,
                    model=self.learner.model_cfg,
                ))
            self.logger = Logger(experiment_cfg)

    def step(
        self, state: np.ndarray, action: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, np.float64, np.ndarray, bool]:
        """Carry out single env step and return info

        Params:
            state (np.ndarray): current env state
            action (np.ndarray): action to be executed

        """
        next_state, reward, done, _ = self.env.step(
            self.action_selector.rescale_action(action))
        return state, action, reward, next_state, done

    def train(self):
        """Main training loop"""
        step = 0
        for episode_i in range(self.experiment_info.total_num_episodes):
            # Test when we have to
            if episode_i % self.experiment_info.test_interval == 0:
                policy_copy = self.learner.get_policy(self.use_cuda)
                average_test_score = self.test(policy_copy,
                                               self.action_selector, episode_i,
                                               self.update_step)
                if self.experiment_info.log_wandb:
                    self.logger.write_log(
                        log_dict=dict(average_test_score=average_test_score), )
                self.learner.save_params()

            # Run episode
            state = self.env.reset()
            losses = dict(critic1_loss=[],
                          critic2_loss=[],
                          actor_loss=[],
                          alpha_loss=[])
            episode_reward = 0
            done = False

            while not done:
                if self.experiment_info.train_render:
                    self.env.render()

                action = self.action_selector(self.learner.actor, state,
                                              episode_i)
                state, action, reward, next_state, done = self.step(
                    state, action)

                episode_reward = episode_reward + reward
                step = step + 1

                if self.use_n_step:
                    transition = [state, action, reward, next_state, done]
                    self.transition_queue.append(transition)
                    if len(self.transition_queue) == self.hyper_params.n_step:
                        n_step_transition = preprocess_nstep(
                            self.transition_queue, self.hyper_params.gamma)
                        self.replay_buffer.add(*n_step_transition)
                else:
                    self.replay_buffer.add(state, action, reward, next_state,
                                           done)

                state = next_state

                if len(self.replay_buffer) >= self.hyper_params.update_starting_point:
                    experience = self.replay_buffer.sample()
                    info = self.learner.update_model(
                        self._preprocess_experience(experience))
                    self.update_step = self.update_step + 1
                    critic1_loss, critic2_loss, actor_loss, alpha_loss = info[:4]
                    losses["critic1_loss"].append(critic1_loss)
                    losses["critic2_loss"].append(critic2_loss)
                    losses["actor_loss"].append(actor_loss)
                    losses["alpha_loss"].append(alpha_loss)

                    if self.hyper_params.use_per:
                        indices, new_priorities = info[-2:]
                        self.replay_buffer.update_priorities(
                            indices, new_priorities)

            print(
                f"[TRAIN] episode num: {episode_i} | update step: {self.update_step} |"
                f" episode reward: {episode_reward}")

            if self.experiment_info.log_wandb:
                log_dict = dict(episode_reward=episode_reward)
                if self.update_step > 0:
                    log_dict["critic1_loss"] = np.mean(losses["critic1_loss"])
                    log_dict["critic2_loss"] = np.mean(losses["critic2_loss"])
                    log_dict["actor_loss"] = np.mean(losses["actor_loss"])
                    log_dict["alpha_loss"] = np.mean(losses["alpha_loss"])
                self.logger.write_log(log_dict)

    def _preprocess_experience(
            self, experience: Tuple[np.ndarray]) -> Tuple[torch.Tensor]:
        """Convert collected experience to pytorch tensors."""
        states, actions, rewards, next_states, dones = experience[:5]
        if self.hyper_params.use_per:
            indices, weights = experience[-2:]

        states = np2tensor(states, self.use_cuda)
        actions = np2tensor(actions, self.use_cuda)
        rewards = np2tensor(rewards.reshape(-1, 1), self.use_cuda)
        next_states = np2tensor(next_states, self.use_cuda)
        dones = np2tensor(dones.reshape(-1, 1), self.use_cuda)

        experience = (states, actions, rewards, next_states, dones)

        if self.hyper_params.use_per:
            weights = np2tensor(weights.reshape(-1, 1), self.use_cuda)
            experience = experience + (
                indices,
                weights,
            )

        return experience
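
A construction sketch for an agent like SACAgent, assuming the base Agent builds the gym environment from experiment_info.env and that the three configs are plain OmegaConf objects. The keys and values below are placeholders chosen to match the attributes read in the code above; the project's actual config schema is not shown here and will differ.

from omegaconf import OmegaConf

# Placeholder configs mirroring the fields accessed in SACAgent (assumption).
experiment_info = OmegaConf.create({
    "env": {"name": "Pendulum-v1"},
    "total_num_episodes": 500,
    "test_interval": 50,
    "train_render": False,
    "log_wandb": False,
})
hyper_params = OmegaConf.create({
    "n_step": 1,
    "gamma": 0.99,
    "use_per": False,
    "update_starting_point": 1000,
    "max_exploratory_steps": 10000,
})
model_cfg = OmegaConf.create({})  # actor/critic definitions would go here

agent = SACAgent(experiment_info, hyper_params, model_cfg)
agent.train()
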
Example #5
class DDPGAgent(Agent):
    """Configurable DQN base agent; works with Dueling DQN, C51, QR-DQN, etc.

    Attributes:
        env (gym.ENV): Gym environment for RL agent
        learner (LEARNER): Carries and updates agent value/policy models
        replay_buffer (ReplayBuffer): Replay buffer for experience replay (PER as wrapper)
        action_selector (ActionSelector): Callable for action selection (OUNoise as wrapper)
        use_n_step (bool): Indication of using n-step updates
        transition_queue (Deque): deque for tracking and preprocessing n-step transitions

    """
    def __init__(
        self,
        experiment_info: DictConfig,
        hyper_params: DictConfig,
        model_cfg: DictConfig,
    ):
        Agent.__init__(self, experiment_info, hyper_params, model_cfg)
        self.use_n_step = self.hyper_params.n_step > 1
        self.transition_queue = deque(maxlen=self.hyper_params.n_step)
        self.update_step = 0

        self._initialize()

    def _initialize(self):
        """Initialize agent components."""
        # Define env specific model params
        self.experiment_info.env.state_dim = self.env.observation_space.shape[0]
        self.experiment_info.env.action_dim = self.env.action_space.shape[0]
        self.experiment_info.env.action_range = [
            self.env.action_space.low.tolist(),
            self.env.action_space.high.tolist(),
        ]

        # Build learner
        self.learner = build_learner(self.experiment_info, self.hyper_params,
                                     self.model_cfg)

        # Build replay buffer, wrap with PER buffer if using it
        self.replay_buffer = ReplayBuffer(self.hyper_params)
        if self.hyper_params.use_per:
            self.replay_buffer = PrioritizedReplayBuffer(
                self.replay_buffer, self.hyper_params)

        # Build action selector
        self.action_selector = build_action_selector(self.experiment_info)
        self.action_selector = OUNoise(self.action_selector,
                                       self.env.action_space)

        # Build logger
        if self.experiment_info.log_wandb:
            experiment_cfg = OmegaConf.create(
                dict(
                    experiment_info=self.experiment_info,
                    hyper_params=self.hyper_params,
                    model=self.learner.model_cfg,
                ))
            self.logger = Logger(experiment_cfg)

    def step(
            self, state: np.ndarray, action: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, np.float64, np.ndarray, bool]:
        """Carry out single env step and return info."""
        next_state, reward, done, _ = self.env.step(action)
        return state, action, reward, next_state, done

    def train(self):
        """Main training loop."""
        step = 0
        for episode_i in range(self.experiment_info.total_num_episodes):
            state = self.env.reset()
            episode_reward = 0
            done = False

            while not done:
                if self.experiment_info.train_render:
                    self.env.render()

                action = self.action_selector(self.learner.actor, state)
                state, action, reward, next_state, done = self.step(
                    state, action)

                episode_reward = episode_reward + reward
                step = step + 1

                if self.use_n_step:
                    transition = [state, action, reward, next_state, done]
                    self.transition_queue.append(transition)
                    if len(self.transition_queue) == self.hyper_params.n_step:
                        n_step_transition = preprocess_nstep(
                            self.transition_queue, self.hyper_params.gamma)
                        self.replay_buffer.add(*n_step_transition)
                else:
                    self.replay_buffer.add(state, action, reward, next_state,
                                           done)

                state = next_state

                if len(self.replay_buffer) >= self.hyper_params.update_starting_point:
                    experience = self.replay_buffer.sample()
                    info = self.learner.update_model(
                        self._preprocess_experience(experience))
                    self.update_step = self.update_step + 1

                    if self.hyper_params.use_per:
                        indices, new_priorities = info[-2:]
                        self.replay_buffer.update_priorities(
                            indices, new_priorities)

                    if self.experiment_info.log_wandb:
                        self.logger.write_log(log_dict=dict(
                            critic1_loss=info[0],
                            critic2_loss=info[1],
                            actor_loss=info[2],
                        ), )

            print(
                f"[TRAIN] episode num: {episode_i} | update step: {self.update_step} |"
                f" episode reward: {episode_reward}")

            if self.experiment_info.log_wandb:
                self.logger.write_log(log_dict=dict(
                    episode_reward=episode_reward))

            if episode_i % self.experiment_info.test_interval == 0:
                policy_copy = self.learner.get_policy(self.device)
                average_test_score = self.test(policy_copy,
                                               self.action_selector, episode_i,
                                               self.update_step)
                if self.experiment_info.log_wandb:
                    self.logger.write_log(
                        log_dict=dict(average_test_score=average_test_score), )

                self.learner.save_params()

    def _preprocess_experience(self, experience: Tuple[np.ndarray]):
        """Convert collected experience to pytorch tensors."""
        states, actions, rewards, next_states, dones = experience[:5]
        if self.hyper_params.use_per:
            indices, weights = experience[-2:]

        states = np2tensor(states, self.device)
        actions = np2tensor(actions, self.device)
        rewards = np2tensor(rewards.reshape(-1, 1), self.device)
        next_states = np2tensor(next_states, self.device)
        dones = np2tensor(dones.reshape(-1, 1), self.device)

        experience = (states, actions, rewards, next_states, dones)

        if self.hyper_params.use_per:
            weights = np2tensor(weights.reshape(-1, 1), self.device)
            experience = experience + (
                indices,
                weights,
            )

        return experience
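
preprocess_nstep is called with the transition deque and gamma. A typical n-step reduction folds the queued rewards into one discounted return and pairs the first state/action with the last next_state/done, stopping at a terminal step. The helper below sketches that computation as an assumption about the utility's behavior, not its actual implementation.

def nstep_from_queue(transition_queue, gamma):
    """Illustrative n-step reduction (assumed behavior, not the library helper)."""
    state, action = transition_queue[0][:2]
    next_state, done = transition_queue[-1][3:]
    n_step_reward = 0.0
    for i, (_, _, reward, _, step_done) in enumerate(transition_queue):
        n_step_reward += (gamma ** i) * reward
        if step_done:  # stop accumulating past a terminal step
            next_state, done = transition_queue[i][3:]
            break
    return state, action, n_step_reward, next_state, done
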
Example #6
class A2CAgent(Agent):
    """Synchronous Advantage Actor Critic (A2C; data parallel) agent

    Attributes:
        learner (Learner): learner for A2C
        update_step (int): update step counter
        action_selector (ActionSelector): action selector for testing
        logger (Logger): WandB logger

    """
    def __init__(
        self,
        experiment_info: DictConfig,
        hyper_params: DictConfig,
        model_cfg: DictConfig,
    ):
        Agent.__init__(self, experiment_info, hyper_params, model_cfg)
        self.update_step = 0

        self._initialize()

    def _initialize(self):
        """Set env specific configs and build learner."""
        self.experiment_info.env.state_dim = self.env.observation_space.shape[0]
        if self.experiment_info.env.is_discrete:
            self.experiment_info.env.action_dim = self.env.action_space.n
        else:
            self.experiment_info.env.action_dim = self.env.action_space.shape[0]
            self.experiment_info.env.action_range = [
                self.env.action_space.low.tolist(),
                self.env.action_space.high.tolist(),
            ]

        self.learner = build_learner(self.experiment_info, self.hyper_params,
                                     self.model_cfg)

        self.action_selector = build_action_selector(self.experiment_info,
                                                     self.use_cuda)

        # Build logger
        if self.experiment_info.log_wandb:
            experiment_cfg = OmegaConf.create(
                dict(
                    experiment_info=self.experiment_info,
                    hyper_params=self.hyper_params,
                    model=self.learner.model_cfg,
                ))
            self.logger = Logger(experiment_cfg)

    def step(
        self, state: np.ndarray, action: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, np.float64, np.ndarray, bool]:
        """Carry out one environment step"""
        # A2C only uses this for test
        next_state, reward, done, _ = self.env.step(action)
        return state, action, reward, next_state, done

    def train(self):
        """Run data parellel training (A2C)."""
        ray.init()
        workers = []
        for worker_id in range(self.experiment_info.num_workers):
            worker = ray.remote(num_cpus=1)(TrajectoryRolloutWorker).remote(
                worker_id, self.experiment_info, self.learner.model_cfg.actor)
            workers.append(worker)

        print("Starting training...")
        time.sleep(1)
        while self.update_step < self.experiment_info.max_update_steps:
            # Run and retrieve trajectories
            trajectory_infos = ray.get(
                [worker.run_trajectory.remote() for worker in workers])

            # Run update step with multiple trajectories
            trajectories_tensor = [
                self._preprocess_trajectory(traj["trajectory"])
                for traj in trajectory_infos
            ]
            info = self.learner.update_model(trajectories_tensor)
            self.update_step = self.update_step + 1

            # Synchronize worker policies
            policy_state_dict = self.learner.actor.state_dict()
            for worker in workers:
                worker.synchronize_policy.remote(policy_state_dict)

            if self.experiment_info.log_wandb:
                worker_average_score = np.mean(
                    [traj["score"] for traj in trajectory_infos])
                log_dict = dict(episode_reward=worker_average_score)
                if self.update_step > 0:
                    log_dict["critic_loss"] = info[0]
                    log_dict["actor_loss"] = info[1]
                self.logger.write_log(log_dict)

            if self.update_step % self.experiment_info.test_interval == 0:
                policy_copy = self.learner.get_policy(self.use_cuda)
                average_test_score = self.test(
                    policy_copy,
                    self.action_selector,
                    self.update_step,
                    self.update_step,
                )
                if self.experiment_info.log_wandb:
                    self.logger.write_log(
                        log_dict=dict(average_test_score=average_test_score), )

                self.learner.save_params()

    def _preprocess_trajectory(
            self, trajectory: Tuple[np.ndarray, ...]) -> Tuple[torch.Tensor]:
        """Preprocess trajectory for pytorch training"""
        states, actions, rewards = trajectory

        states = np2tensor(states, self.use_cuda)
        actions = np2tensor(actions.reshape(-1, 1), self.use_cuda)
        rewards = np2tensor(rewards.reshape(-1, 1), self.use_cuda)

        if self.experiment_info.env.is_discrete:
            actions = actions.long()

        trajectory = (states, actions, rewards)

        return trajectory
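
The data-parallel pattern in train(), remote rollout workers whose trajectories are gathered with ray.get and whose policies are then re-synchronized, reduces to the generic Ray skeleton below. The worker class and its contents are illustrative stand-ins, not the library's TrajectoryRolloutWorker.

import ray

@ray.remote(num_cpus=1)
class RolloutWorker:
    """Illustrative rollout worker, not the library's TrajectoryRolloutWorker."""

    def __init__(self, worker_id):
        self.worker_id = worker_id
        self.weights = None

    def run_trajectory(self):
        # ... collect a trajectory with the local policy ...
        return {"worker_id": self.worker_id, "trajectory": [], "score": 0.0}

    def synchronize_policy(self, state_dict):
        self.weights = state_dict

ray.init()
workers = [RolloutWorker.remote(i) for i in range(4)]
trajectory_infos = ray.get([w.run_trajectory.remote() for w in workers])
for w in workers:
    w.synchronize_policy.remote({"dummy": 0})  # the agent pushes learner.actor.state_dict() here
ray.shutdown()
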
Example #7
class A3CAgent(Agent):
    """Asynchronous Advantage Actor Critic (A3C) agent

    Attributes:
        learner (Learner): learner for A3C
        update_step (int): update step counter
        action_selector (ActionSelector): action selector for testing
        logger (Logger): WandB logger
    """
    def __init__(
        self,
        experiment_info: DictConfig,
        hyper_params: DictConfig,
        model_cfg: DictConfig,
    ):
        Agent.__init__(self, experiment_info, hyper_params, model_cfg)
        self.update_step = 0

        self._initialize()

    def _initialize(self):
        """Set env specific configs and build learner."""
        self.experiment_info.env.state_dim = self.env.observation_space.shape[0]
        if self.experiment_info.env.is_discrete:
            self.experiment_info.env.action_dim = self.env.action_space.n
        else:
            self.experiment_info.env.action_dim = self.env.action_space.shape[0]
            self.experiment_info.env.action_range = [
                self.env.action_space.low.tolist(),
                self.env.action_space.high.tolist(),
            ]

        self.learner = build_learner(self.experiment_info, self.hyper_params,
                                     self.model_cfg)

        self.action_selector = build_action_selector(self.experiment_info,
                                                     self.use_cuda)

        # Build logger
        if self.experiment_info.log_wandb:
            experiment_cfg = OmegaConf.create(
                dict(
                    experiment_info=self.experiment_info,
                    hyper_params=self.hyper_params,
                    model=self.learner.model_cfg,
                ))
            self.logger = Logger(experiment_cfg)

    def step(
        self, state: np.ndarray, action: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, np.float64, np.ndarray, bool]:
        """Carry out one environment step"""
        # A3C only uses this for test
        next_state, reward, done, _ = self.env.step(action)
        return state, action, reward, next_state, done

    def train(self):
        """Run gradient parallel training (A3C)."""
        ray.init()
        workers = []
        for worker_id in range(self.experiment_info.num_workers):
            worker = TrajectoryRolloutWorker(worker_id, self.experiment_info,
                                             self.learner.model_cfg.actor)

            # Wrap worker with ComputesGradients wrapper
            worker = ray.remote(num_cpus=1)(ComputesGradients).remote(
                worker, self.hyper_params, self.learner.model_cfg)
            workers.append(worker)

        gradients = {}
        for worker in workers:
            gradients[worker.compute_grads_with_traj.remote()] = worker

        while self.update_step < self.experiment_info.max_update_steps:
            computed_grads_ids, _ = ray.wait(list(gradients))
            if computed_grads_ids:
                # Retrieve computed gradients
                computed_grads, step_info = ray.get(computed_grads_ids[0])
                critic_grads, actor_grads = computed_grads

                # Apply computed gradients and update models
                self.learner.critic_optimizer.zero_grad()
                for param, grad in zip(self.learner.critic.parameters(),
                                       critic_grads):
                    param.grad = grad
                self.learner.critic_optimizer.step()

                self.learner.actor_optimizer.zero_grad()
                for param, grad in zip(self.learner.actor.parameters(),
                                       actor_grads):
                    param.grad = grad
                self.learner.actor_optimizer.step()

                self.update_step = self.update_step + 1

                # Synchronize worker models with the updated models and get them running again
                state_dicts = dict(
                    critic=self.learner.critic.state_dict(),
                    actor=self.learner.actor.state_dict(),
                )
                worker = gradients.pop(computed_grads_ids[0])
                worker.synchronize.remote(state_dicts)
                gradients[worker.compute_grads_with_traj.remote()] = worker

                if self.experiment_info.log_wandb:
                    log_dict = dict()
                    if step_info["worker_rank"] == 0:
                        log_dict["Worker 0 score"] = step_info["score"]
                    if self.update_step > 0:
                        log_dict["critic_loss"] = step_info["critic_loss"]
                        log_dict["actor_loss"] = step_info["actor_loss"]
                    self.logger.write_log(log_dict)

            if self.update_step % self.experiment_info.test_interval == 0:
                policy_copy = self.learner.get_policy(self.use_cuda)
                average_test_score = self.test(
                    policy_copy,
                    self.action_selector,
                    self.update_step,
                    self.update_step,
                )
                if self.experiment_info.log_wandb:
                    self.logger.write_log(
                        log_dict=dict(average_test_score=average_test_score), )

                self.learner.save_params()

        ray.shutdown()
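
The central update in the A3C loop installs worker-computed gradients directly into param.grad before calling the optimizer. The toy snippet below isolates that pattern with a throwaway linear model and zero gradients; it is illustrative only.

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Pretend these gradients arrived from a remote worker (here: zeros of the right shapes).
remote_grads = [torch.zeros_like(p) for p in model.parameters()]

optimizer.zero_grad()
for param, grad in zip(model.parameters(), remote_grads):
    param.grad = grad          # install the externally computed gradient
optimizer.step()               # update the central weights with it
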
Example #8
class DQNBaseAgent(Agent):
    """Configurable DQN base agent; works with Dueling DQN, C51, QR-DQN, etc

    Attributes:
        env (gym.ENV): Gym environment for RL agent
        learner (LEARNER): Carries and updates agent value/policy models
        replay_buffer (ReplayBuffer): Replay buffer for experience replay (PER as wrapper)
        action_selector (ActionSelector): Callable for DQN action selection (EpsGreedy as wrapper)
        use_n_step (bool): Indication of using n-step updates
        transition_queue (Deque): deque for tracking and preprocessing n-step transitions
        logger (Logger): WandB logger
    """
    def __init__(
        self,
        experiment_info: DictConfig,
        hyper_params: DictConfig,
        model_cfg: DictConfig,
    ):
        Agent.__init__(self, experiment_info, hyper_params, model_cfg)
        self.use_n_step = self.hyper_params.n_step > 1
        self.transition_queue = deque(maxlen=self.hyper_params.n_step)
        self.update_step = 0

        self._initialize()

    def _initialize(self):
        """Initialize agent components"""
        # set env specific model params
        self.model_cfg.params.model_cfg.state_dim = self.env.observation_space.shape
        self.model_cfg.params.model_cfg.action_dim = self.env.action_space.n
        self.model_cfg.params.model_cfg.use_cuda = self.use_cuda

        # Build learner
        self.learner = build_learner(self.experiment_info, self.hyper_params,
                                     self.model_cfg)

        # Build replay buffer, wrap with PER buffer if using it
        self.replay_buffer = ReplayBuffer(self.hyper_params)
        if self.hyper_params.use_per:
            self.replay_buffer = PrioritizedReplayBuffer(
                self.replay_buffer, self.hyper_params)

        # Build action selector, wrap with e-greedy exploration
        self.action_selector = build_action_selector(self.experiment_info,
                                                     self.use_cuda)
        self.action_selector = EpsGreedy(self.action_selector,
                                         self.env.action_space,
                                         self.hyper_params)

        if self.experiment_info.log_wandb:
            experiment_cfg = OmegaConf.create(
                dict(
                    experiment_info=self.experiment_info,
                    hyper_params=self.hyper_params,
                    model=self.learner.model_cfg,
                ))
            self.logger = Logger(experiment_cfg)

    def step(
        self, state: np.ndarray, action: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, np.float64, np.ndarray, bool]:
        """Carry out single env step and return info"""
        next_state, reward, done, _ = self.env.step(action)
        return state, action, reward, next_state, done

    def train(self):
        """Main training loop"""
        step = 0
        for episode_i in range(self.experiment_info.total_num_episodes):
            # Test when we have to
            if episode_i % self.experiment_info.test_interval == 0:
                policy_copy = self.learner.get_policy(self.use_cuda)
                average_test_score = self.test(policy_copy,
                                               self.action_selector, episode_i,
                                               self.update_step)
                if self.experiment_info.log_wandb:
                    self.logger.write_log(
                        log_dict=dict(average_test_score=average_test_score), )
                self.learner.save_params()

            # Carry out episode
            state = self.env.reset()
            losses = []
            episode_reward = 0
            done = False

            while not done:
                if self.experiment_info.render_train:
                    self.env.render()

                action = self.action_selector(self.learner.network, state)
                state, action, reward, next_state, done = self.step(
                    state, action)
                episode_reward = episode_reward + reward
                step = step + 1

                if self.use_n_step:
                    transition = [state, action, reward, next_state, done]
                    self.transition_queue.append(transition)
                    if len(self.transition_queue) == self.hyper_params.n_step:
                        n_step_transition = preprocess_nstep(
                            self.transition_queue, self.hyper_params.gamma)
                        self.replay_buffer.add(*n_step_transition)
                else:
                    self.replay_buffer.add(state, action, reward, next_state,
                                           done)

                state = next_state

                if len(self.replay_buffer) >= self.hyper_params.update_starting_point:
                    if step % self.hyper_params.train_freq == 0:
                        experience = self.replay_buffer.sample()

                        info = self.learner.update_model(
                            self._preprocess_experience(experience))
                        self.update_step = self.update_step + 1
                        losses.append(info[0])

                        if self.hyper_params.use_per:
                            indices, new_priorities = info[-2:]
                            self.replay_buffer.update_priorities(
                                indices, new_priorities)

                self.action_selector.decay_epsilon()

            print(f"[TRAIN] episode num: {episode_i} "
                  f"| update step: {self.update_step} "
                  f"| episode reward: {episode_reward} "
                  f"| epsilon: {round(self.action_selector.eps, 5)}")

            if self.experiment_info.log_wandb:
                log_dict = dict(episode_reward=episode_reward,
                                epsilon=self.action_selector.eps)
                if self.update_step > 0:
                    log_dict["mean_loss"] = np.mean(losses)
                self.logger.write_log(log_dict=log_dict)

    def _preprocess_experience(
            self, experience: Tuple[np.ndarray]) -> Tuple[torch.Tensor]:
        """Convert numpy experiences to tensor: MEMORY """
        states, actions, rewards, next_states, dones = experience[:5]
        if self.hyper_params.use_per:
            indices, weights = experience[-2:]

        states = np2tensor(states, self.use_cuda)
        actions = np2tensor(actions.reshape(-1, 1), self.use_cuda)
        rewards = np2tensor(rewards.reshape(-1, 1), self.use_cuda)
        next_states = np2tensor(next_states, self.use_cuda)
        dones = np2tensor(dones.reshape(-1, 1), self.use_cuda)

        experience = (states, actions.long(), rewards, next_states, dones)

        if self.hyper_params.use_per:
            weights = np2tensor(weights, self.use_cuda)
            experience = experience + (
                indices,
                weights,
            )

        return experience
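
np2tensor is used throughout the _preprocess_experience methods to move numpy batches onto the training device. The helper below is a plausible stand-in written as an assumption about its behavior; the real utility may handle dtypes and devices differently.

import numpy as np
import torch

def np2tensor_sketch(array: np.ndarray, use_cuda: bool) -> torch.Tensor:
    """Plausible stand-in for the np2tensor utility (assumed behavior)."""
    tensor = torch.from_numpy(array).float()
    return tensor.cuda() if use_cuda else tensor

batch = np.random.rand(32, 4).astype(np.float32)
states = np2tensor_sketch(batch, use_cuda=False)  # shape (32, 4), dtype float32
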