Example #1
class TD3(AttributeSavingMixin, BatchAgent):
    """Twin Delayed Deep Deterministic Policy Gradients (TD3).

    See http://arxiv.org/abs/1802.09477

    Args:
        policy (Policy): Policy.
        q_func1 (Module): First Q-function that takes state-action pairs as input
            and outputs predicted Q-values.
        q_func2 (Module): Second Q-function that takes state-action pairs as
            input and outputs predicted Q-values.
        policy_optimizer (Optimizer): Optimizer set up with the policy.
        q_func1_optimizer (Optimizer): Optimizer set up with the first
            Q-function.
        q_func2_optimizer (Optimizer): Optimizer set up with the second
            Q-function.
        replay_buffer (ReplayBuffer): Replay buffer.
        gamma (float): Discount factor.
        explorer (Explorer): Explorer that specifies an exploration strategy.
        gpu (int): GPU device id. If None or negative, the CPU is used.
        replay_start_size (int): If the replay buffer's size is less than
            replay_start_size, updates are skipped.
        minibatch_size (int): Minibatch size.
        update_interval (int): Model update interval in steps.
        phi (callable): Feature extractor applied to observations.
        soft_update_tau (float): Tau of the soft target update.
        logger (Logger): Logger to use.
        batch_states (callable): Method that makes a batch of observations.
            The default is `pfrl.utils.batch_states.batch_states`.
        burnin_action_func (callable or None): If not None, this callable
            is used to select actions until the model has been updated at
            least once during training.
        policy_update_delay (int): Delay of policy updates. Policy is updated
            once in `policy_update_delay` times of Q-function updates.
        target_policy_smoothing_func (callable): Callable that takes a batch of
            actions as input and outputs a noisy version of it. It is used for
            target policy smoothing when computing target Q-values.
    """

    saved_attributes = (
        "policy",
        "q_func1",
        "q_func2",
        "target_policy",
        "target_q_func1",
        "target_q_func2",
        "policy_optimizer",
        "q_func1_optimizer",
        "q_func2_optimizer",
    )

    def __init__(
        self,
        policy,
        q_func1,
        q_func2,
        policy_optimizer,
        q_func1_optimizer,
        q_func2_optimizer,
        replay_buffer,
        gamma,
        explorer,
        gpu=None,
        replay_start_size=10000,
        minibatch_size=100,
        update_interval=1,
        phi=lambda x: x,
        soft_update_tau=5e-3,
        n_times_update=1,
        max_grad_norm=None,
        logger=getLogger(__name__),
        batch_states=batch_states,
        burnin_action_func=None,
        policy_update_delay=2,
        target_policy_smoothing_func=default_target_policy_smoothing_func,
    ):

        self.policy = policy
        self.q_func1 = q_func1
        self.q_func2 = q_func2

        if gpu is not None and gpu >= 0:
            assert torch.cuda.is_available()
            self.device = torch.device("cuda:{}".format(gpu))
            self.policy.to(self.device)
            self.q_func1.to(self.device)
            self.q_func2.to(self.device)
        else:
            self.device = torch.device("cpu")

        self.replay_buffer = replay_buffer
        self.gamma = gamma
        self.explorer = explorer
        self.gpu = gpu
        self.phi = phi
        self.soft_update_tau = soft_update_tau
        self.logger = logger
        self.policy_optimizer = policy_optimizer
        self.q_func1_optimizer = q_func1_optimizer
        self.q_func2_optimizer = q_func2_optimizer
        self.replay_updater = ReplayUpdater(
            replay_buffer=replay_buffer,
            update_func=self.update,
            batchsize=minibatch_size,
            # Note: fixed to 1; the constructor's n_times_update argument is
            # not forwarded here.
            n_times_update=1,
            replay_start_size=replay_start_size,
            update_interval=update_interval,
            episodic_update=False,
        )
        self.max_grad_norm = max_grad_norm
        self.batch_states = batch_states
        self.burnin_action_func = burnin_action_func
        self.policy_update_delay = policy_update_delay
        self.target_policy_smoothing_func = target_policy_smoothing_func

        self.t = 0
        self.policy_n_updates = 0
        self.q_func_n_updates = 0
        self.last_state = None
        self.last_action = None

        # Target model
        self.target_policy = copy.deepcopy(
            self.policy).eval().requires_grad_(False)
        self.target_q_func1 = copy.deepcopy(
            self.q_func1).eval().requires_grad_(False)
        self.target_q_func2 = copy.deepcopy(
            self.q_func2).eval().requires_grad_(False)

        # Statistics
        self.q1_record = collections.deque(maxlen=1000)
        self.q2_record = collections.deque(maxlen=1000)
        self.q_func1_loss_record = collections.deque(maxlen=100)
        self.q_func2_loss_record = collections.deque(maxlen=100)
        self.policy_loss_record = collections.deque(maxlen=100)

    def sync_target_network(self):
        """Synchronize target network with current network."""
        synchronize_parameters(
            src=self.policy,
            dst=self.target_policy,
            method="soft",
            tau=self.soft_update_tau,
        )
        synchronize_parameters(
            src=self.q_func1,
            dst=self.target_q_func1,
            method="soft",
            tau=self.soft_update_tau,
        )
        synchronize_parameters(
            src=self.q_func2,
            dst=self.target_q_func2,
            method="soft",
            tau=self.soft_update_tau,
        )

    def update_q_func(self, batch):
        """Compute losses for both Q-functions and update them."""

        batch_next_state = batch["next_state"]
        batch_rewards = batch["reward"]
        batch_terminal = batch["is_state_terminal"]
        batch_state = batch["state"]
        batch_actions = batch["action"]
        batch_discount = batch["discount"]

        with torch.no_grad(), pfrl.utils.evaluating(
                self.target_policy), pfrl.utils.evaluating(
                    self.target_q_func1), pfrl.utils.evaluating(
                        self.target_q_func2):
            next_actions = self.target_policy_smoothing_func(
                self.target_policy(batch_next_state).sample())
            next_q1 = self.target_q_func1((batch_next_state, next_actions))
            next_q2 = self.target_q_func2((batch_next_state, next_actions))
            next_q = torch.min(next_q1, next_q2)

            target_q = batch_rewards + batch_discount * (
                1.0 - batch_terminal) * torch.flatten(next_q)

        predict_q1 = torch.flatten(self.q_func1((batch_state, batch_actions)))
        predict_q2 = torch.flatten(self.q_func2((batch_state, batch_actions)))

        loss1 = F.mse_loss(target_q, predict_q1)
        loss2 = F.mse_loss(target_q, predict_q2)

        # Update stats
        self.q1_record.extend(predict_q1.detach().cpu().numpy())
        self.q2_record.extend(predict_q2.detach().cpu().numpy())
        self.q_func1_loss_record.append(float(loss1))
        self.q_func2_loss_record.append(float(loss2))

        self.q_func1_optimizer.zero_grad()
        loss1.backward()
        if self.max_grad_norm is not None:
            clip_l2_grad_norm_(self.q_func1.parameters(), self.max_grad_norm)
        self.q_func1_optimizer.step()

        self.q_func2_optimizer.zero_grad()
        loss2.backward()
        if self.max_grad_norm is not None:
            clip_l2_grad_norm_(self.q_func2.parameters(), self.max_grad_norm)
        self.q_func2_optimizer.step()

        self.q_func_n_updates += 1

    def update_policy(self, batch):
        """Compute the actor loss and update the policy."""

        batch_state = batch["state"]

        onpolicy_actions = self.policy(batch_state).rsample()
        q = self.q_func1((batch_state, onpolicy_actions))

        # Since we want to maximize Q, loss is negation of Q
        loss = -torch.mean(q)

        self.policy_loss_record.append(float(loss))
        self.policy_optimizer.zero_grad()
        loss.backward()
        if self.max_grad_norm is not None:
            clip_l2_grad_norm_(self.policy.parameters(), self.max_grad_norm)
        self.policy_optimizer.step()
        self.policy_n_updates += 1

    def update(self, experiences, errors_out=None):
        """Update the model from experiences"""

        batch = batch_experiences(experiences, self.device, self.phi,
                                  self.gamma)
        self.update_q_func(batch)
        if self.q_func_n_updates % self.policy_update_delay == 0:
            self.update_policy(batch)
            self.sync_target_network()

    def batch_select_onpolicy_action(self, batch_obs):
        with torch.no_grad(), pfrl.utils.evaluating(self.policy):
            batch_xs = self.batch_states(batch_obs, self.device, self.phi)
            batch_action = self.policy(batch_xs).sample().cpu().numpy()
        return list(batch_action)

    def batch_act(self, batch_obs):
        if self.training:
            return self._batch_act_train(batch_obs)
        else:
            return self._batch_act_eval(batch_obs)

    def batch_observe(self, batch_obs, batch_reward, batch_done, batch_reset):
        if self.training:
            self._batch_observe_train(batch_obs, batch_reward, batch_done,
                                      batch_reset)

    def _batch_act_eval(self, batch_obs):
        assert not self.training
        return self.batch_select_onpolicy_action(batch_obs)

    def _batch_act_train(self, batch_obs):
        assert self.training
        if self.burnin_action_func is not None and self.policy_n_updates == 0:
            batch_action = [
                self.burnin_action_func() for _ in range(len(batch_obs))
            ]
        else:
            batch_onpolicy_action = self.batch_select_onpolicy_action(
                batch_obs)
            batch_action = [
                self.explorer.select_action(self.t,
                                            lambda: batch_onpolicy_action[i])
                for i in range(len(batch_onpolicy_action))
            ]

        self.batch_last_obs = list(batch_obs)
        self.batch_last_action = list(batch_action)

        return batch_action

    def _batch_observe_train(self, batch_obs, batch_reward, batch_done,
                             batch_reset):
        assert self.training
        for i in range(len(batch_obs)):
            self.t += 1
            if self.batch_last_obs[i] is not None:
                assert self.batch_last_action[i] is not None
                # Add a transition to the replay buffer
                self.replay_buffer.append(
                    state=self.batch_last_obs[i],
                    action=self.batch_last_action[i],
                    reward=batch_reward[i],
                    next_state=batch_obs[i],
                    next_action=None,
                    is_state_terminal=batch_done[i],
                    env_id=i,
                )
                if batch_reset[i] or batch_done[i]:
                    self.batch_last_obs[i] = None
                    self.batch_last_action[i] = None
                    self.replay_buffer.stop_current_episode(env_id=i)
            self.replay_updater.update_if_necessary(self.t)

    def get_statistics(self):
        return [
            ("average_q1", _mean_or_nan(self.q1_record)),
            ("average_q2", _mean_or_nan(self.q2_record)),
            ("average_q_func1_loss", _mean_or_nan(self.q_func1_loss_record)),
            ("average_q_func2_loss", _mean_or_nan(self.q_func2_loss_record)),
            ("average_policy_loss", _mean_or_nan(self.policy_loss_record)),
            ("policy_n_updates", self.policy_n_updates),
            ("q_func_n_updates", self.q_func_n_updates),
        ]
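The `update_q_func` method above perturbs the target policy's actions with `target_policy_smoothing_func` before bootstrapping. As a rough sketch of what such a smoothing function does (the constants below follow the TD3 paper's defaults and are assumptions, not necessarily the exact values used by `default_target_policy_smoothing_func`):

import torch

def clipped_gaussian_smoothing(batch_action, sigma=0.2, noise_clip=0.5, action_bound=1.0):
    # Add clipped Gaussian noise to target actions, then clip to the action range.
    noise = torch.clamp(sigma * torch.randn_like(batch_action), -noise_clip, noise_clip)
    return torch.clamp(batch_action + noise, -action_bound, action_bound)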
Example #2
    def __init__(
        self,
        policy,
        q_func1,
        q_func2,
        policy_optimizer,
        q_func1_optimizer,
        q_func2_optimizer,
        replay_buffer,
        gamma,
        explorer,
        gpu=None,
        replay_start_size=10000,
        minibatch_size=100,
        update_interval=1,
        phi=lambda x: x,
        soft_update_tau=5e-3,
        n_times_update=1,
        max_grad_norm=None,
        logger=getLogger(__name__),
        batch_states=batch_states,
        burnin_action_func=None,
        policy_update_delay=2,
        target_policy_smoothing_func=default_target_policy_smoothing_func,
    ):

        self.policy = policy
        self.q_func1 = q_func1
        self.q_func2 = q_func2

        if gpu is not None and gpu >= 0:
            assert torch.cuda.is_available()
            self.device = torch.device("cuda:{}".format(gpu))
            self.policy.to(self.device)
            self.q_func1.to(self.device)
            self.q_func2.to(self.device)
        else:
            self.device = torch.device("cpu")

        self.replay_buffer = replay_buffer
        self.gamma = gamma
        self.explorer = explorer
        self.gpu = gpu
        self.phi = phi
        self.soft_update_tau = soft_update_tau
        self.logger = logger
        self.policy_optimizer = policy_optimizer
        self.q_func1_optimizer = q_func1_optimizer
        self.q_func2_optimizer = q_func2_optimizer
        self.replay_updater = ReplayUpdater(
            replay_buffer=replay_buffer,
            update_func=self.update,
            batchsize=minibatch_size,
            # Note: fixed to 1; the constructor's n_times_update argument is
            # not forwarded here.
            n_times_update=1,
            replay_start_size=replay_start_size,
            update_interval=update_interval,
            episodic_update=False,
        )
        self.max_grad_norm = max_grad_norm
        self.batch_states = batch_states
        self.burnin_action_func = burnin_action_func
        self.policy_update_delay = policy_update_delay
        self.target_policy_smoothing_func = target_policy_smoothing_func

        self.t = 0
        self.policy_n_updates = 0
        self.q_func_n_updates = 0
        self.last_state = None
        self.last_action = None

        # Target model
        self.target_policy = copy.deepcopy(
            self.policy).eval().requires_grad_(False)
        self.target_q_func1 = copy.deepcopy(
            self.q_func1).eval().requires_grad_(False)
        self.target_q_func2 = copy.deepcopy(
            self.q_func2).eval().requires_grad_(False)

        # Statistics
        self.q1_record = collections.deque(maxlen=1000)
        self.q2_record = collections.deque(maxlen=1000)
        self.q_func1_loss_record = collections.deque(maxlen=100)
        self.q_func2_loss_record = collections.deque(maxlen=100)
        self.policy_loss_record = collections.deque(maxlen=100)
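Both TD3 constructors above create target copies of the policy and Q-functions and later blend them toward the online networks via `synchronize_parameters(..., method="soft", tau=soft_update_tau)`. A minimal sketch of such a soft (Polyak) update, assuming plain `torch.nn.Module` networks rather than pfrl's helper:

import torch

@torch.no_grad()
def soft_update(src: torch.nn.Module, dst: torch.nn.Module, tau: float) -> None:
    # dst_param <- tau * src_param + (1 - tau) * dst_param
    for src_p, dst_p in zip(src.parameters(), dst.parameters()):
        dst_p.mul_(1.0 - tau).add_(tau * src_p)

With the default soft_update_tau=5e-3, the target networks track the online networks slowly, which stabilizes the bootstrapped Q targets.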
Example #3
class DQN(agent.AttributeSavingMixin, agent.BatchAgent):
    """Deep Q-Network algorithm.

    Args:
        q_function (StateQFunction): Q-function
        optimizer (Optimizer): Optimizer that is already set up.
        replay_buffer (ReplayBuffer): Replay buffer.
        gamma (float): Discount factor.
        explorer (Explorer): Explorer that specifies an exploration strategy.
        gpu (int): GPU device id. If None or negative, the CPU is used.
        replay_start_size (int): If the replay buffer's size is less than
            replay_start_size, updates are skipped.
        minibatch_size (int): Minibatch size.
        update_interval (int): Model update interval in steps.
        target_update_interval (int): Target model update interval in steps.
        clip_delta (bool): If True, the TD error is clipped (Huber loss) when
            computing the value loss.
        phi (callable): Feature extractor applied to observations.
        target_update_method (str): 'hard' or 'soft'.
        soft_update_tau (float): Tau of the soft target update.
        n_times_update (int): Number of times the model is updated per
            update interval.
        batch_accumulator (str): 'mean' or 'sum'.
        episodic_update_len (int or None): Subsequences of this length are used
            for updates when set to an int and `recurrent=True`.
        logger (Logger): Logger to use.
        batch_states (callable): Method that makes a batch of observations.
            The default is `pfrl.utils.batch_states.batch_states`.
        recurrent (bool): If set to True, `model` is assumed to implement
            `pfrl.nn.Recurrent` and is updated in a recurrent
            manner.
        max_grad_norm (float or None): Maximum L2 norm of the gradient used for
            gradient clipping. If set to None, the gradient is not clipped.
    """

    saved_attributes = ("model", "target_model", "optimizer")

    def __init__(
        self,
        q_function,
        optimizer,
        replay_buffer,
        gamma,
        explorer,
        gpu=None,
        replay_start_size=50000,
        minibatch_size=32,
        update_interval=1,
        target_update_interval=10000,
        clip_delta=True,
        phi=lambda x: x,
        target_update_method="hard",
        soft_update_tau=1e-2,
        n_times_update=1,
        batch_accumulator="mean",
        episodic_update_len=None,
        logger=getLogger(__name__),
        batch_states=batch_states,
        recurrent=False,
        max_grad_norm=None,
    ):
        self.rnd_reward = 0
        self.ngu_reward = 0
        self.model = q_function

        if gpu is not None and gpu >= 0:
            assert torch.cuda.is_available()
            self.device = torch.device("cuda:{}".format(gpu))
            self.model.to(self.device)
        else:
            self.device = torch.device("cpu")

        self.replay_buffer = replay_buffer
        self.optimizer = optimizer
        self.gamma = gamma
        self.explorer = explorer
        self.gpu = gpu
        self.target_update_interval = target_update_interval
        self.clip_delta = clip_delta
        self.phi = phi
        self.target_update_method = target_update_method
        self.soft_update_tau = soft_update_tau
        self.batch_accumulator = batch_accumulator
        assert batch_accumulator in ("mean", "sum")
        self.logger = logger
        self.batch_states = batch_states
        self.recurrent = recurrent
        if self.recurrent:
            update_func = self.update_from_episodes
        else:
            update_func = self.update
        self.replay_updater = ReplayUpdater(
            replay_buffer=replay_buffer,
            update_func=update_func,
            batchsize=minibatch_size,
            episodic_update=recurrent,
            episodic_update_len=episodic_update_len,
            n_times_update=n_times_update,
            replay_start_size=replay_start_size,
            update_interval=update_interval,
        )
        self.minibatch_size = minibatch_size
        self.episodic_update_len = episodic_update_len
        self.replay_start_size = replay_start_size
        self.update_interval = update_interval
        self.max_grad_norm = max_grad_norm

        assert (
            target_update_interval % update_interval == 0
        ), "target_update_interval should be a multiple of update_interval"

        self.t = 0
        self.optim_t = 0  # Compensate pytorch optim not having `t`
        self._cumulative_steps = 0
        self.last_state = None
        self.last_action = None
        self.target_model = None
        self.sync_target_network()

        # Statistics
        self.q_record = collections.deque(maxlen=1000)
        self.loss_record = collections.deque(maxlen=100)

        # Recurrent states of the model
        self.train_recurrent_states = None
        self.train_prev_recurrent_states = None
        self.test_recurrent_states = None

        self.replay_buffer_lock = None

        # Error checking
        if (self.replay_buffer.capacity is not None
                and self.replay_buffer.capacity <
                self.replay_updater.replay_start_size):
            raise ValueError(
                "Replay start size cannot exceed replay buffer capacity.")

    def set_rnd_module(self, rnd_module):
        self.rnd_module = rnd_module
        self.rnd_reward = 1

    def set_ngu_module(self, ngu_module):
        self.ngu_module = ngu_module
        self.ngu_reward = 1

    @property
    def cumulative_steps(self):
        # cumulative_steps counts the overall steps during the training.
        return self._cumulative_steps

    def _setup_actor_learner_training(self, n_actors, actor_update_interval,
                                      update_counter):
        assert actor_update_interval > 0

        self.actor_update_interval = actor_update_interval
        self.update_counter = update_counter

        # Make a copy on shared memory and share among actors and the poller
        shared_model = copy.deepcopy(self.model).cpu()
        shared_model.share_memory()

        # Pipes are used for infrequent communication
        learner_pipes, actor_pipes = list(
            zip(*[mp.Pipe() for _ in range(n_actors)]))

        return (shared_model, learner_pipes, actor_pipes)

    def sync_target_network(self):
        """Synchronize target network with current network."""
        if self.target_model is None:
            self.target_model = copy.deepcopy(self.model)

            def flatten_parameters(mod):
                if isinstance(mod, torch.nn.RNNBase):
                    mod.flatten_parameters()

            # RNNBase.flatten_parameters must be called again after deep-copy.
            # See: https://discuss.pytorch.org/t/why-do-we-need-flatten-parameters-when-using-rnn-with-dataparallel/46506  # NOQA
            self.target_model.apply(flatten_parameters)
            # Set the target network to evaluation mode.
            self.target_model.eval()
        else:
            synchronize_parameters(
                src=self.model,
                dst=self.target_model,
                method=self.target_update_method,
                tau=self.soft_update_tau,
            )

    def update(self, experiences, errors_out=None):
        """Update the model from experiences

        Args:
            experiences (list): List of lists of dicts.
                For DQN, each dict must contain:
                  - state (object): State
                  - action (object): Action
                  - reward (float): Reward
                  - is_state_terminal (bool): True iff next state is terminal
                  - next_state (object): Next state
                  - weight (float, optional): Weight coefficient. It can be
                    used for importance sampling.
            errors_out (list or None): If set to a list, then TD-errors
                computed from the given experiences are appended to the list.

        Returns:
            None
        """
        has_weight = "weight" in experiences[0][0]
        exp_batch = batch_experiences(
            experiences,
            device=self.device,
            phi=self.phi,
            gamma=self.gamma,
            batch_states=self.batch_states,
        )
        if self.rnd_reward:
            self.rnd_module.train(exp_batch)
        # if self.ngu_reward:
        #     self.ngu_module.train(exp_batch)
        if has_weight:
            exp_batch["weights"] = torch.tensor(  # pylint: disable=not-callable
                [elem[0]["weight"] for elem in experiences],
                device=self.device,
                dtype=torch.float32,
            )
            if errors_out is None:
                errors_out = []
        loss = self._compute_loss(exp_batch, errors_out=errors_out)
        if has_weight:
            self.replay_buffer.update_errors(errors_out)

        self.loss_record.append(float(loss.detach().cpu().numpy()))

        self.optimizer.zero_grad()
        loss.backward()
        if self.max_grad_norm is not None:
            pfrl.utils.clip_l2_grad_norm_(self.model.parameters(),
                                          self.max_grad_norm)
        self.optimizer.step()
        self.optim_t += 1

    def update_from_episodes(self, episodes, errors_out=None):
        assert errors_out is None, "Recurrent DQN does not support PrioritizedBuffer"
        episodes = sorted(episodes, key=len, reverse=True)
        exp_batch = batch_recurrent_experiences(
            episodes,
            device=self.device,
            phi=self.phi,
            gamma=self.gamma,
            batch_states=self.batch_states,
        )
        loss = self._compute_loss(exp_batch, errors_out=None)
        self.loss_record.append(float(loss.detach().cpu().numpy()))
        self.optimizer.zero_grad()
        loss.backward()
        if self.max_grad_norm is not None:
            pfrl.utils.clip_l2_grad_norm_(self.model.parameters(),
                                          self.max_grad_norm)
        self.optimizer.step()
        self.optim_t += 1

    def _compute_target_values(self, exp_batch):
        batch_next_state = exp_batch["next_state"]

        if self.recurrent:
            target_next_qout, _ = pack_and_forward(
                self.target_model,
                batch_next_state,
                exp_batch["next_recurrent_state"],
            )
        else:
            target_next_qout = self.target_model(batch_next_state)
        next_q_max = target_next_qout.max

        batch_terminal = exp_batch["is_state_terminal"]
        discount = exp_batch["discount"]
        batch_rewards = exp_batch["reward"]

        return batch_rewards + discount * (1.0 - batch_terminal) * next_q_max

    def _compute_y_and_t(self, exp_batch):
        batch_size = exp_batch["reward"].shape[0]

        # Compute Q-values for current states
        batch_state = exp_batch["state"]

        if self.recurrent:
            qout, _ = pack_and_forward(self.model, batch_state,
                                       exp_batch["recurrent_state"])
        else:
            qout = self.model(batch_state)

        batch_actions = exp_batch["action"]
        batch_q = torch.reshape(qout.evaluate_actions(batch_actions),
                                (batch_size, 1))

        with torch.no_grad():
            batch_q_target = torch.reshape(
                self._compute_target_values(exp_batch), (batch_size, 1))

        return batch_q, batch_q_target

    def _compute_loss(self, exp_batch, errors_out=None):
        """Compute the Q-learning loss for a batch of experiences.

        Args:
            exp_batch (dict): A dict of batched arrays of transitions.
        Returns:
            Computed loss from the minibatch of experiences.
        """
        y, t = self._compute_y_and_t(exp_batch)

        self.q_record.extend(y.detach().cpu().numpy().ravel())

        if errors_out is not None:
            del errors_out[:]
            delta = torch.abs(y - t)
            if delta.ndim == 2:
                delta = torch.sum(delta, dim=1)
            delta = delta.detach().cpu().numpy()
            for e in delta:
                errors_out.append(e)

        if "weights" in exp_batch:
            return compute_weighted_value_loss(
                y,
                t,
                exp_batch["weights"],
                clip_delta=self.clip_delta,
                batch_accumulator=self.batch_accumulator,
            )
        else:
            return compute_value_loss(
                y,
                t,
                clip_delta=self.clip_delta,
                batch_accumulator=self.batch_accumulator,
            )

    def _evaluate_model_and_update_recurrent_states(self, batch_obs):
        batch_xs = self.batch_states(batch_obs, self.device, self.phi)
        if self.recurrent:
            if self.training:
                self.train_prev_recurrent_states = self.train_recurrent_states
                batch_av, self.train_recurrent_states = one_step_forward(
                    self.model, batch_xs, self.train_recurrent_states)
            else:
                batch_av, self.test_recurrent_states = one_step_forward(
                    self.model, batch_xs, self.test_recurrent_states)
        else:
            batch_av = self.model(batch_xs)
        return batch_av

    def batch_act(self, batch_obs):
        with torch.no_grad(), evaluating(self.model):
            batch_av = self._evaluate_model_and_update_recurrent_states(
                batch_obs)
            batch_argmax = batch_av.greedy_actions.cpu().numpy()
        if self.training:
            batch_action = [
                self.explorer.select_action(
                    self.t,
                    lambda: batch_argmax[i],
                    action_value=batch_av[i:i + 1],
                ) for i in range(len(batch_obs))
            ]
            self.batch_last_obs = list(batch_obs)
            self.batch_last_action = list(batch_action)
        else:
            batch_action = batch_argmax
        return batch_action

    def _batch_observe_train(self, batch_obs, batch_reward, batch_done,
                             batch_reset):

        for i in range(len(batch_obs)):
            self.t += 1
            self._cumulative_steps += 1
            # Update the target network
            if self.t % self.target_update_interval == 0:
                self.sync_target_network()
            if self.batch_last_obs[i] is not None:
                assert self.batch_last_action[i] is not None
                # Add a transition to the replay buffer
                transition = {
                    "state": self.batch_last_obs[i],
                    "action": self.batch_last_action[i],
                    "reward": batch_reward[i],
                    "next_state": batch_obs[i],
                    "next_action": None,
                    "is_state_terminal": batch_done[i],
                }
                if self.recurrent:
                    transition["recurrent_state"] = recurrent_state_as_numpy(
                        get_recurrent_state_at(
                            self.train_prev_recurrent_states, i, detach=True))
                    transition[
                        "next_recurrent_state"] = recurrent_state_as_numpy(
                            get_recurrent_state_at(self.train_recurrent_states,
                                                   i,
                                                   detach=True))
                self.replay_buffer.append(env_id=i, **transition)
                if batch_reset[i] or batch_done[i]:
                    self.batch_last_obs[i] = None
                    self.batch_last_action[i] = None
                    self.replay_buffer.stop_current_episode(env_id=i)
            self.replay_updater.update_if_necessary(self.t)

        if self.recurrent:
            # Reset recurrent states when episodes end
            self.train_prev_recurrent_states = None
            self.train_recurrent_states = _batch_reset_recurrent_states_when_episodes_end(  # NOQA
                batch_done=batch_done,
                batch_reset=batch_reset,
                recurrent_states=self.train_recurrent_states,
            )

    def _batch_observe_eval(self, batch_obs, batch_reward, batch_done,
                            batch_reset):
        if self.recurrent:
            # Reset recurrent states when episodes end
            self.test_recurrent_states = _batch_reset_recurrent_states_when_episodes_end(  # NOQA
                batch_done=batch_done,
                batch_reset=batch_reset,
                recurrent_states=self.test_recurrent_states,
            )

    def batch_observe(self, batch_obs, batch_reward, batch_done, batch_reset):
        if self.training:
            return self._batch_observe_train(batch_obs, batch_reward,
                                             batch_done, batch_reset)
        else:
            return self._batch_observe_eval(batch_obs, batch_reward,
                                            batch_done, batch_reset)

    def _can_start_replay(self):
        if len(self.replay_buffer) < self.replay_start_size:
            return False
        if self.recurrent and self.replay_buffer.n_episodes < self.minibatch_size:
            return False
        return True

    def _poll_pipe(self, actor_idx, pipe, replay_buffer_lock, exception_event):
        if pipe.closed:
            return
        try:
            while pipe.poll() and not exception_event.is_set():
                cmd, data = pipe.recv()
                if cmd == "get_statistics":
                    assert data is None
                    with replay_buffer_lock:
                        stats = self.get_statistics()
                    pipe.send(stats)
                elif cmd == "load":
                    self.load(data)
                    pipe.send(None)
                elif cmd == "save":
                    self.save(data)
                    pipe.send(None)
                elif cmd == "transition":
                    with replay_buffer_lock:
                        if "env_id" not in data:
                            data["env_id"] = actor_idx
                        self.replay_buffer.append(**data)
                        self._cumulative_steps += 1
                elif cmd == "stop_episode":
                    idx = actor_idx if data is None else data
                    with replay_buffer_lock:
                        self.replay_buffer.stop_current_episode(env_id=idx)
                        stats = self.get_statistics()
                    pipe.send(stats)

                else:
                    raise RuntimeError(
                        "Unknown command from actor: {}".format(cmd))
        except EOFError:
            pipe.close()
        except Exception:
            self.logger.exception("Poller loop failed. Exiting")
            exception_event.set()

    def _learner_loop(
        self,
        shared_model,
        pipes,
        replay_buffer_lock,
        stop_event,
        exception_event,
        n_updates=None,
    ):
        try:
            update_counter = 0
            # To stop this loop, call stop_event.set()
            while not stop_event.is_set():
                # Update model if possible
                if not self._can_start_replay():
                    continue
                if n_updates is not None:
                    assert self.optim_t <= n_updates
                    if self.optim_t == n_updates:
                        stop_event.set()
                        break

                if self.recurrent:
                    with replay_buffer_lock:
                        episodes = self.replay_buffer.sample_episodes(
                            self.minibatch_size, self.episodic_update_len)
                    self.update_from_episodes(episodes)
                else:
                    with replay_buffer_lock:
                        transitions = self.replay_buffer.sample(
                            self.minibatch_size)
                    self.update(transitions)

                # Update the shared model. This can be expensive if GPU is used
                # since this is a DtoH copy, so it is updated only at regular
                # intervals.
                update_counter += 1
                if update_counter % self.actor_update_interval == 0:
                    with self.update_counter.get_lock():
                        self.update_counter.value += 1
                        shared_model.load_state_dict(self.model.state_dict())

                # To keep the ratio of target updates to model updates,
                # here we calculate back the effective current timestep
                # from update_interval and number of updates so far.
                effective_timestep = self.optim_t * self.update_interval
                # We can safely assign self.t since in the learner
                # it isn't updated by any other method
                self.t = effective_timestep
                if effective_timestep % self.target_update_interval == 0:
                    self.sync_target_network()
        except Exception:
            self.logger.exception("Learner loop failed. Exiting")
            exception_event.set()

    def _poller_loop(self, shared_model, pipes, replay_buffer_lock, stop_event,
                     exception_event):
        # To stop this loop, call stop_event.set()
        while not stop_event.is_set() and not exception_event.is_set():
            time.sleep(1e-6)
            # Poll actors for messages
            for i, pipe in enumerate(pipes):
                self._poll_pipe(i, pipe, replay_buffer_lock, exception_event)

    def setup_actor_learner_training(self,
                                     n_actors,
                                     update_counter=None,
                                     n_updates=None,
                                     actor_update_interval=8):
        if update_counter is None:
            update_counter = mp.Value(ctypes.c_ulong)

        (shared_model, learner_pipes,
         actor_pipes) = self._setup_actor_learner_training(
             n_actors, actor_update_interval, update_counter)
        exception_event = mp.Event()

        def make_actor(i):
            return pfrl.agents.StateQFunctionActor(
                pipe=actor_pipes[i],
                model=shared_model,
                explorer=self.explorer,
                phi=self.phi,
                batch_states=self.batch_states,
                logger=self.logger,
                recurrent=self.recurrent,
            )

        replay_buffer_lock = mp.Lock()
        self.replay_buffer_lock = replay_buffer_lock

        poller_stop_event = mp.Event()
        poller = pfrl.utils.StoppableThread(
            target=self._poller_loop,
            kwargs=dict(
                shared_model=shared_model,
                pipes=learner_pipes,
                replay_buffer_lock=replay_buffer_lock,
                stop_event=poller_stop_event,
                exception_event=exception_event,
            ),
            stop_event=poller_stop_event,
        )

        learner_stop_event = mp.Event()
        learner = pfrl.utils.StoppableThread(
            target=self._learner_loop,
            kwargs=dict(
                shared_model=shared_model,
                pipes=learner_pipes,
                replay_buffer_lock=replay_buffer_lock,
                stop_event=learner_stop_event,
                n_updates=n_updates,
                exception_event=exception_event,
            ),
            stop_event=learner_stop_event,
        )

        return make_actor, learner, poller, exception_event

    def stop_episode(self):
        if self.recurrent:
            self.test_recurrent_states = None

    def get_statistics(self):
        return [
            ("average_q", _mean_or_nan(self.q_record)),
            ("average_loss", _mean_or_nan(self.loss_record)),
            ("cumulative_steps", self.cumulative_steps),
            ("n_updates", self.optim_t),
            ("rlen", len(self.replay_buffer)),
        ]
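`_compute_target_values` above builds the standard Q-learning bootstrap target from the target network's action values. Spelled out for the non-recurrent case with plain tensors (a sketch that assumes `next_q_values` is a `(batch, n_actions)` tensor rather than pfrl's `ActionValue` wrapper):

import torch
import torch.nn.functional as F

def dqn_target(reward, terminal, discount, next_q_values):
    # y = r + gamma * (1 - done) * max_a Q_target(s', a), as in _compute_target_values.
    next_q_max = next_q_values.max(dim=1).values
    return reward + discount * (1.0 - terminal) * next_q_max

`_compute_loss` then penalizes the difference between the predicted Q-values and these targets; the `clip_delta` flag chooses between a Huber-style loss and a plain squared error (the helper below is illustrative only, not pfrl's `compute_value_loss`):

def value_loss(y, t, clip_delta=True):
    # Huber loss when clip_delta is True, plain squared error otherwise
    # (mean accumulation).
    if clip_delta:
        return F.smooth_l1_loss(y, t, reduction="mean")
    return F.mse_loss(y, t, reduction="mean")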
Example #4
    def __init__(
        self,
        q_function,
        optimizer,
        replay_buffer,
        gamma,
        explorer,
        gpu=None,
        replay_start_size=50000,
        minibatch_size=32,
        update_interval=1,
        target_update_interval=10000,
        clip_delta=True,
        phi=lambda x: x,
        target_update_method="hard",
        soft_update_tau=1e-2,
        n_times_update=1,
        batch_accumulator="mean",
        episodic_update_len=None,
        logger=getLogger(__name__),
        batch_states=batch_states,
        recurrent=False,
        max_grad_norm=None,
    ):
        self.rnd_reward = 0
        self.ngu_reward = 0
        self.model = q_function

        if gpu is not None and gpu >= 0:
            assert torch.cuda.is_available()
            self.device = torch.device("cuda:{}".format(gpu))
            self.model.to(self.device)
        else:
            self.device = torch.device("cpu")

        self.replay_buffer = replay_buffer
        self.optimizer = optimizer
        self.gamma = gamma
        self.explorer = explorer
        self.gpu = gpu
        self.target_update_interval = target_update_interval
        self.clip_delta = clip_delta
        self.phi = phi
        self.target_update_method = target_update_method
        self.soft_update_tau = soft_update_tau
        self.batch_accumulator = batch_accumulator
        assert batch_accumulator in ("mean", "sum")
        self.logger = logger
        self.batch_states = batch_states
        self.recurrent = recurrent
        if self.recurrent:
            update_func = self.update_from_episodes
        else:
            update_func = self.update
        self.replay_updater = ReplayUpdater(
            replay_buffer=replay_buffer,
            update_func=update_func,
            batchsize=minibatch_size,
            episodic_update=recurrent,
            episodic_update_len=episodic_update_len,
            n_times_update=n_times_update,
            replay_start_size=replay_start_size,
            update_interval=update_interval,
        )
        self.minibatch_size = minibatch_size
        self.episodic_update_len = episodic_update_len
        self.replay_start_size = replay_start_size
        self.update_interval = update_interval
        self.max_grad_norm = max_grad_norm

        assert (
            target_update_interval % update_interval == 0
        ), "target_update_interval should be a multiple of update_interval"

        self.t = 0
        self.optim_t = 0  # Compensate pytorch optim not having `t`
        self._cumulative_steps = 0
        self.last_state = None
        self.last_action = None
        self.target_model = None
        self.sync_target_network()

        # Statistics
        self.q_record = collections.deque(maxlen=1000)
        self.loss_record = collections.deque(maxlen=100)

        # Recurrent states of the model
        self.train_recurrent_states = None
        self.train_prev_recurrent_states = None
        self.test_recurrent_states = None

        self.replay_buffer_lock = None

        # Error checking
        if (self.replay_buffer.capacity is not None
                and self.replay_buffer.capacity <
                self.replay_updater.replay_start_size):
            raise ValueError(
                "Replay start size cannot exceed replay buffer capacity.")
Example #5
class SoftActorCritic(AttributeSavingMixin, BatchAgent):
    """Soft Actor-Critic (SAC).

    See https://arxiv.org/abs/1812.05905

    Args:
        policy (Policy): Policy.
        q_func1 (Module): First Q-function that takes state-action pairs as input
            and outputs predicted Q-values.
        q_func2 (Module): Second Q-function that takes state-action pairs as
            input and outputs predicted Q-values.
        policy_optimizer (Optimizer): Optimizer set up with the policy.
        q_func1_optimizer (Optimizer): Optimizer set up with the first
            Q-function.
        q_func2_optimizer (Optimizer): Optimizer set up with the second
            Q-function.
        replay_buffer (ReplayBuffer): Replay buffer.
        gamma (float): Discount factor.
        gpu (int): GPU device id. If None or negative, the CPU is used.
        replay_start_size (int): If the replay buffer's size is less than
            replay_start_size, updates are skipped.
        minibatch_size (int): Minibatch size.
        update_interval (int): Model update interval in steps.
        phi (callable): Feature extractor applied to observations.
        soft_update_tau (float): Tau of the soft target update.
        logger (Logger): Logger to use.
        batch_states (callable): Method that makes a batch of observations.
            The default is `pfrl.utils.batch_states.batch_states`.
        burnin_action_func (callable or None): If not None, this callable
            is used to select actions until the model has been updated at
            least once during training.
        initial_temperature (float): Initial temperature value. If
            `entropy_target` is None, the temperature is fixed to this value.
        entropy_target (float or None): If set to a float, the temperature is
            adjusted during training so that the policy's entropy matches
            this value.
        temperature_optimizer_lr (float): Learning rate of the temperature
            optimizer. If set to None, Adam with default hyperparameters
            is used.
        act_deterministically (bool): If set to True, the most probable
            actions are chosen in the act method instead of sampling from
            the policy distribution.
    """

    saved_attributes = (
        "policy",
        "q_func1",
        "q_func2",
        "target_q_func1",
        "target_q_func2",
        "policy_optimizer",
        "q_func1_optimizer",
        "q_func2_optimizer",
        "temperature_holder",
        "temperature_optimizer",
    )

    def __init__(
        self,
        policy,
        q_func1,
        q_func2,
        policy_optimizer,
        q_func1_optimizer,
        q_func2_optimizer,
        replay_buffer,
        gamma,
        gpu=None,
        replay_start_size=10000,
        minibatch_size=100,
        update_interval=1,
        phi=lambda x: x,
        soft_update_tau=5e-3,
        max_grad_norm=None,
        logger=getLogger(__name__),
        batch_states=batch_states,
        burnin_action_func=None,
        initial_temperature=1.0,
        entropy_target=None,
        temperature_optimizer_lr=None,
        act_deterministically=True,
    ):

        self.policy = policy
        self.q_func1 = q_func1
        self.q_func2 = q_func2

        if gpu is not None and gpu >= 0:
            assert torch.cuda.is_available()
            self.device = torch.device("cuda:{}".format(gpu))
            self.policy.to(self.device)
            self.q_func1.to(self.device)
            self.q_func2.to(self.device)
        else:
            self.device = torch.device("cpu")

        self.replay_buffer = replay_buffer
        self.gamma = gamma
        self.gpu = gpu
        self.phi = phi
        self.soft_update_tau = soft_update_tau
        self.logger = logger
        self.policy_optimizer = policy_optimizer
        self.q_func1_optimizer = q_func1_optimizer
        self.q_func2_optimizer = q_func2_optimizer
        self.replay_updater = ReplayUpdater(
            replay_buffer=replay_buffer,
            update_func=self.update,
            batchsize=minibatch_size,
            n_times_update=1,
            replay_start_size=replay_start_size,
            update_interval=update_interval,
            episodic_update=False,
        )
        self.max_grad_norm = max_grad_norm
        self.batch_states = batch_states
        self.burnin_action_func = burnin_action_func
        self.initial_temperature = initial_temperature
        self.entropy_target = entropy_target
        if self.entropy_target is not None:
            self.temperature_holder = TemperatureHolder(
                initial_log_temperature=np.log(initial_temperature)
            )
            if temperature_optimizer_lr is not None:
                self.temperature_optimizer = torch.optim.Adam(
                    self.temperature_holder.parameters(), lr=temperature_optimizer_lr
                )
            else:
                self.temperature_optimizer = torch.optim.Adam(
                    self.temperature_holder.parameters()
                )
            if gpu is not None and gpu >= 0:
                self.temperature_holder.to(self.device)
        else:
            self.temperature_holder = None
            self.temperature_optimizer = None
        self.act_deterministically = act_deterministically

        self.t = 0

        # Target model
        self.target_q_func1 = copy.deepcopy(self.q_func1).eval().requires_grad_(False)
        self.target_q_func2 = copy.deepcopy(self.q_func2).eval().requires_grad_(False)

        # Statistics
        self.q1_record = collections.deque(maxlen=1000)
        self.q2_record = collections.deque(maxlen=1000)
        self.entropy_record = collections.deque(maxlen=1000)
        self.q_func1_loss_record = collections.deque(maxlen=100)
        self.q_func2_loss_record = collections.deque(maxlen=100)
        self.n_policy_updates = 0

    @property
    def temperature(self):
        if self.entropy_target is None:
            return self.initial_temperature
        else:
            with torch.no_grad():
                return float(self.temperature_holder())

    def sync_target_network(self):
        """Synchronize target network with current network."""
        synchronize_parameters(
            src=self.q_func1,
            dst=self.target_q_func1,
            method="soft",
            tau=self.soft_update_tau,
        )
        synchronize_parameters(
            src=self.q_func2,
            dst=self.target_q_func2,
            method="soft",
            tau=self.soft_update_tau,
        )

    def update_q_func(self, batch):
        """Compute losses for both Q-functions and update them."""

        batch_next_state = batch["next_state"]
        batch_rewards = batch["reward"]
        batch_terminal = batch["is_state_terminal"]
        batch_state = batch["state"]
        batch_actions = batch["action"]
        batch_discount = batch["discount"]

        with torch.no_grad(), pfrl.utils.evaluating(self.policy), pfrl.utils.evaluating(
            self.target_q_func1
        ), pfrl.utils.evaluating(self.target_q_func2):
            next_action_distrib = self.policy(batch_next_state)
            next_actions = next_action_distrib.sample()
            next_log_prob = next_action_distrib.log_prob(next_actions)
            next_q1 = self.target_q_func1((batch_next_state, next_actions))
            next_q2 = self.target_q_func2((batch_next_state, next_actions))
            next_q = torch.min(next_q1, next_q2)
            entropy_term = self.temperature * next_log_prob[..., None]
            assert next_q.shape == entropy_term.shape

            target_q = batch_rewards + batch_discount * (
                1.0 - batch_terminal
            ) * torch.flatten(next_q - entropy_term)

        predict_q1 = torch.flatten(self.q_func1((batch_state, batch_actions)))
        predict_q2 = torch.flatten(self.q_func2((batch_state, batch_actions)))

        loss1 = 0.5 * F.mse_loss(target_q, predict_q1)
        loss2 = 0.5 * F.mse_loss(target_q, predict_q2)

        # Update stats
        self.q1_record.extend(predict_q1.detach().cpu().numpy())
        self.q2_record.extend(predict_q2.detach().cpu().numpy())
        self.q_func1_loss_record.append(float(loss1))
        self.q_func2_loss_record.append(float(loss2))

        self.q_func1_optimizer.zero_grad()
        loss1.backward()
        if self.max_grad_norm is not None:
            clip_l2_grad_norm_(self.q_func1.parameters(), self.max_grad_norm)
        self.q_func1_optimizer.step()

        self.q_func2_optimizer.zero_grad()
        loss2.backward()
        if self.max_grad_norm is not None:
            clip_l2_grad_norm_(self.q_func2.parameters(), self.max_grad_norm)
        self.q_func2_optimizer.step()

    def update_temperature(self, log_prob):
        assert not log_prob.requires_grad
        loss = -torch.mean(self.temperature_holder() * (log_prob + self.entropy_target))
        self.temperature_optimizer.zero_grad()
        loss.backward()
        if self.max_grad_norm is not None:
            clip_l2_grad_norm_(self.temperature_holder.parameters(), self.max_grad_norm)
        self.temperature_optimizer.step()

    def update_policy_and_temperature(self, batch):
        """Compute the actor loss, update the policy, and optionally the temperature."""

        batch_state = batch["state"]

        action_distrib = self.policy(batch_state)
        actions = action_distrib.rsample()
        log_prob = action_distrib.log_prob(actions)
        q1 = self.q_func1((batch_state, actions))
        q2 = self.q_func2((batch_state, actions))
        q = torch.min(q1, q2)

        entropy_term = self.temperature * log_prob[..., None]
        assert q.shape == entropy_term.shape
        loss = torch.mean(entropy_term - q)

        self.policy_optimizer.zero_grad()
        loss.backward()
        if self.max_grad_norm is not None:
            clip_l2_grad_norm_(self.policy.parameters(), self.max_grad_norm)
        self.policy_optimizer.step()

        self.n_policy_updates += 1

        if self.entropy_target is not None:
            self.update_temperature(log_prob.detach())

        # Record entropy
        with torch.no_grad():
            try:
                self.entropy_record.extend(
                    action_distrib.entropy().detach().cpu().numpy()
                )
            except NotImplementedError:
                # Record - log p(x) instead
                self.entropy_record.extend(-log_prob.detach().cpu().numpy())

    def update(self, experiences, errors_out=None):
        """Update the model from experiences"""
        batch = batch_experiences(experiences, self.device, self.phi, self.gamma)
        self.update_q_func(batch)
        self.update_policy_and_temperature(batch)
        self.sync_target_network()

    def batch_select_greedy_action(self, batch_obs, deterministic=False):
        with torch.no_grad(), pfrl.utils.evaluating(self.policy):
            batch_xs = self.batch_states(batch_obs, self.device, self.phi)
            policy_out = self.policy(batch_xs)
            if deterministic:
                batch_action = mode_of_distribution(policy_out).cpu().numpy()
            else:
                batch_action = policy_out.sample().cpu().numpy()
        return batch_action

    def batch_act(self, batch_obs):
        if self.training:
            return self._batch_act_train(batch_obs)
        else:
            return self._batch_act_eval(batch_obs)

    def batch_observe(self, batch_obs, batch_reward, batch_done, batch_reset):
        if self.training:
            self._batch_observe_train(batch_obs, batch_reward, batch_done, batch_reset)

    def _batch_act_eval(self, batch_obs):
        assert not self.training
        return self.batch_select_greedy_action(
            batch_obs, deterministic=self.act_deterministically
        )

    def _batch_act_train(self, batch_obs):
        assert self.training
        if self.burnin_action_func is not None and self.n_policy_updates == 0:
            batch_action = [self.burnin_action_func() for _ in range(len(batch_obs))]
        else:
            batch_action = self.batch_select_greedy_action(batch_obs)
        self.batch_last_obs = list(batch_obs)
        self.batch_last_action = list(batch_action)

        return batch_action

    def _batch_observe_train(self, batch_obs, batch_reward, batch_done, batch_reset):
        assert self.training
        for i in range(len(batch_obs)):
            self.t += 1
            if self.batch_last_obs[i] is not None:
                assert self.batch_last_action[i] is not None
                # Add a transition to the replay buffer
                self.replay_buffer.append(
                    state=self.batch_last_obs[i],
                    action=self.batch_last_action[i],
                    reward=batch_reward[i],
                    next_state=batch_obs[i],
                    next_action=None,
                    is_state_terminal=batch_done[i],
                    env_id=i,
                )
                if batch_reset[i] or batch_done[i]:
                    self.batch_last_obs[i] = None
                    self.batch_last_action[i] = None
                    self.replay_buffer.stop_current_episode(env_id=i)
            self.replay_updater.update_if_necessary(self.t)

    def get_statistics(self):
        return [
            ("average_q1", _mean_or_nan(self.q1_record)),
            ("average_q2", _mean_or_nan(self.q2_record)),
            ("average_q_func1_loss", _mean_or_nan(self.q_func1_loss_record)),
            ("average_q_func2_loss", _mean_or_nan(self.q_func2_loss_record)),
            ("n_updates", self.n_policy_updates),
            ("average_entropy", _mean_or_nan(self.entropy_record)),
            ("temperature", self.temperature),
        ]
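In `update_q_func` and `update_policy_and_temperature` above, SAC folds an entropy bonus into both the critic target, y = r + gamma * (1 - d) * (min(Q1', Q2') - alpha * log pi(a'|s')), and the actor loss, while `update_temperature` tunes alpha whenever `entropy_target` is given. A compact sketch of that temperature objective, assuming a learnable scalar `log_alpha` analogous to what `TemperatureHolder` stores:

import torch

log_alpha = torch.zeros(1, requires_grad=True)   # log of the temperature alpha
alpha_optimizer = torch.optim.Adam([log_alpha])
entropy_target = -3.0                            # e.g. -action_dim (hypothetical value)

def temperature_loss(log_prob):
    # Only alpha is optimized here, so log_prob must be detached by the caller.
    return -(log_alpha.exp() * (log_prob.detach() + entropy_target)).mean()

Minimizing this loss raises alpha when the policy's entropy (-log_prob) falls below entropy_target and lowers it otherwise.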
Example #6
    def __init__(
        self,
        policy,
        q_func1,
        q_func2,
        policy_optimizer,
        q_func1_optimizer,
        q_func2_optimizer,
        replay_buffer,
        gamma,
        gpu=None,
        replay_start_size=10000,
        minibatch_size=100,
        update_interval=1,
        phi=lambda x: x,
        soft_update_tau=5e-3,
        max_grad_norm=None,
        logger=getLogger(__name__),
        batch_states=batch_states,
        burnin_action_func=None,
        initial_temperature=1.0,
        entropy_target=None,
        temperature_optimizer_lr=None,
        act_deterministically=True,
    ):

        self.policy = policy
        self.q_func1 = q_func1
        self.q_func2 = q_func2

        if gpu is not None and gpu >= 0:
            assert torch.cuda.is_available()
            self.device = torch.device("cuda:{}".format(gpu))
            self.policy.to(self.device)
            self.q_func1.to(self.device)
            self.q_func2.to(self.device)
        else:
            self.device = torch.device("cpu")

        self.replay_buffer = replay_buffer
        self.gamma = gamma
        self.gpu = gpu
        self.phi = phi
        self.soft_update_tau = soft_update_tau
        self.logger = logger
        self.policy_optimizer = policy_optimizer
        self.q_func1_optimizer = q_func1_optimizer
        self.q_func2_optimizer = q_func2_optimizer
        self.replay_updater = ReplayUpdater(
            replay_buffer=replay_buffer,
            update_func=self.update,
            batchsize=minibatch_size,
            n_times_update=1,
            replay_start_size=replay_start_size,
            update_interval=update_interval,
            episodic_update=False,
        )
        self.max_grad_norm = max_grad_norm
        self.batch_states = batch_states
        self.burnin_action_func = burnin_action_func
        self.initial_temperature = initial_temperature
        self.entropy_target = entropy_target
        if self.entropy_target is not None:
            self.temperature_holder = TemperatureHolder(
                initial_log_temperature=np.log(initial_temperature)
            )
            if temperature_optimizer_lr is not None:
                self.temperature_optimizer = torch.optim.Adam(
                    self.temperature_holder.parameters(), lr=temperature_optimizer_lr
                )
            else:
                self.temperature_optimizer = torch.optim.Adam(
                    self.temperature_holder.parameters()
                )
            if gpu is not None and gpu >= 0:
                self.temperature_holder.to(self.device)
        else:
            self.temperature_holder = None
            self.temperature_optimizer = None
        self.act_deterministically = act_deterministically

        self.t = 0

        # Target model
        self.target_q_func1 = copy.deepcopy(self.q_func1).eval().requires_grad_(False)
        self.target_q_func2 = copy.deepcopy(self.q_func2).eval().requires_grad_(False)

        # Statistics
        self.q1_record = collections.deque(maxlen=1000)
        self.q2_record = collections.deque(maxlen=1000)
        self.entropy_record = collections.deque(maxlen=1000)
        self.q_func1_loss_record = collections.deque(maxlen=100)
        self.q_func2_loss_record = collections.deque(maxlen=100)
        self.n_policy_updates = 0
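
The entropy_target / temperature_holder setup above is the usual SAC-style automatic temperature adjustment. Below is a minimal, self-contained sketch of what a TemperatureHolder-like module and its loss could look like; SimpleTemperatureHolder and compute_temperature_loss are illustrative names, not necessarily PFRL's actual implementation.

import torch
from torch import nn


class SimpleTemperatureHolder(nn.Module):
    """Holds log(temperature) as a learnable scalar (illustrative sketch)."""

    def __init__(self, initial_log_temperature=0.0):
        super().__init__()
        self.log_temperature = nn.Parameter(
            torch.tensor(initial_log_temperature, dtype=torch.float32)
        )

    def forward(self):
        # The temperature must stay positive, so its log is optimized.
        return torch.exp(self.log_temperature)


def compute_temperature_loss(temperature_holder, log_prob, entropy_target):
    # SAC-style objective: adjust the temperature so that the policy's
    # entropy (-log_prob) is pushed towards entropy_target.
    return -torch.mean(
        temperature_holder() * (log_prob + entropy_target).detach()
    )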
Exemplo n.º 7
0
    def __init__(
        self,
        q_function: QNetworkWithValuebuffer,  # torch.nn.Module
        optimizer: torch.optim.Optimizer,  # type: ignore  # somehow mypy complains
        replay_buffer: EVAReplayBuffer,
        gamma: float,
        explorer: Explorer,
        gpu: Optional[int] = None,
        replay_start_size: int = 50000,
        minibatch_size: int = 32,
        update_interval: int = 1,
        target_update_interval: int = 10000,
        clip_delta: bool = True,
        phi: Callable[[Any], Any] = lambda x: x,
        target_update_method: str = "hard",
        soft_update_tau: float = 1e-2,
        n_times_update: int = 1,
        batch_accumulator: str = "mean",
        episodic_update_len: Optional[int] = None,
        interval_tcp: int = 20,
        n_trj_step: int = 50,
        use_eva: bool = True,  # If False, this agent becomes plain DQN.
        logger: Logger = getLogger(__name__),
        batch_states: Callable[
            [Sequence[Any], torch.device, Callable[[Any], Any]], Any
        ] = batch_states,
        recurrent: bool = False,
        max_grad_norm: Optional[float] = None,
    ):
        self.model = q_function

        if gpu is not None and gpu >= 0:
            assert torch.cuda.is_available()
            self.device = torch.device("cuda:{}".format(gpu))
            self.model.to(self.device)
        else:
            self.device = torch.device("cpu")

        self.replay_buffer = replay_buffer
        self.optimizer = optimizer
        self.gamma = gamma
        self.explorer = explorer
        self.gpu = gpu
        self.target_update_interval = target_update_interval
        self.clip_delta = clip_delta
        self.phi = phi
        self.target_update_method = target_update_method
        self.soft_update_tau = soft_update_tau
        self.batch_accumulator = batch_accumulator
        assert batch_accumulator in ("mean", "sum")
        self.logger = logger
        self.batch_states = batch_states
        # Recurrent models are not supported by this agent yet, so the
        # `recurrent` argument is ignored and non-recurrent updates are forced.
        # self.recurrent = recurrent
        self.recurrent = False
        self.n_actions = self.model.n_actions
        self.value_buffer = self.model.v_buffer
        self.interval_tcp = interval_tcp
        self.n_trj_step = n_trj_step
        self.use_eva = use_eva
        update_func: Callable[..., None]
        if self.recurrent:
            assert isinstance(self.replay_buffer, AbstractEpisodicReplayBuffer)
            update_func = self.update_from_episodes
        else:
            update_func = self.update
        self.replay_updater = ReplayUpdater(
            replay_buffer=replay_buffer,
            update_func=update_func,
            batchsize=minibatch_size,
            episodic_update=recurrent,
            episodic_update_len=episodic_update_len,
            n_times_update=n_times_update,
            replay_start_size=replay_start_size,
            update_interval=update_interval,
        )
        self.minibatch_size = minibatch_size
        self.episodic_update_len = episodic_update_len
        self.replay_start_size = replay_start_size
        self.update_interval = update_interval
        self.max_grad_norm = max_grad_norm

        assert (
            target_update_interval % update_interval == 0
        ), "target_update_interval should be a multiple of update_interval"

        self.t = 0
        self.eval_t = 0
        self.optim_t = 0  # Compensate pytorch optim not having `t`
        self._cumulative_steps = 0
        self.target_model = make_target_model_as_copy(self.model.q_function)

        # Statistics
        self.q_record: collections.deque = collections.deque(maxlen=1000)
        self.loss_record: collections.deque = collections.deque(maxlen=100)

        # Recurrent states of the model
        self.train_recurrent_states: Any = None
        self.train_prev_recurrent_states: Any = None
        self.test_recurrent_states: Any = None

        # Error checking
        if (self.replay_buffer.capacity is not None
                and self.replay_buffer.capacity <
                self.replay_updater.replay_start_size):
            raise ValueError(
                "Replay start size cannot exceed replay buffer capacity.")
Exemplo n.º 8
0
class EVA(agent.AttributeSavingMixin, agent.BatchAgent):
    """Ephemeral Value Adjusments

    Args:
        q_function (StateQFunction): Q-function
        optimizer (Optimizer): Optimizer that is already setup
        replay_buffer (ReplayBuffer): Replay buffer
        gamma (float): Discount factor
        explorer (Explorer): Explorer that specifies an exploration strategy.
        gpu (int): GPU device id if not None nor negative.
        replay_start_size (int): if the replay buffer's size is less than
            replay_start_size, skip update
        minibatch_size (int): Minibatch size
        update_interval (int): Model update interval in step
        target_update_interval (int): Target model update interval in step
        clip_delta (bool): Clip delta if set True
        phi (callable): Feature extractor applied to observations
        target_update_method (str): 'hard' or 'soft'.
        soft_update_tau (float): Tau of soft target update.
        n_times_update (int): Number of repetition of update
        batch_accumulator (str): 'mean' or 'sum'
        episodic_update_len (int or None): Subsequences of this length are used
            for update if set int and episodic_update=True
        logger (Logger): Logger used
        batch_states (callable): method which makes a batch of observations.
            default is `pfrl.utils.batch_states.batch_states`
        recurrent (bool): If set to True, `model` is assumed to implement
            `pfrl.nn.Recurrent` and is updated in a recurrent
            manner.
        max_grad_norm (float or None): Maximum L2 norm of the gradient used for
            gradient clipping. If set to None, the gradient is not clipped.
    """

    saved_attributes = ("model", "target_model", "optimizer")

    def __init__(
        self,
        q_function: QNetworkWithValuebuffer,  # torch.nn.Module
        optimizer: torch.optim.Optimizer,  # type: ignore  # somehow mypy complains
        replay_buffer: EVAReplayBuffer,
        gamma: float,
        explorer: Explorer,
        gpu: Optional[int] = None,
        replay_start_size: int = 50000,
        minibatch_size: int = 32,
        update_interval: int = 1,
        target_update_interval: int = 10000,
        clip_delta: bool = True,
        phi: Callable[[Any], Any] = lambda x: x,
        target_update_method: str = "hard",
        soft_update_tau: float = 1e-2,
        n_times_update: int = 1,
        batch_accumulator: str = "mean",
        episodic_update_len: Optional[int] = None,
        interval_tcp: int = 20,
        n_trj_step: int = 50,
        use_eva: bool = True,  # If False, this agent becomes plain DQN.
        logger: Logger = getLogger(__name__),
        batch_states: Callable[
            [Sequence[Any], torch.device, Callable[[Any], Any]], Any
        ] = batch_states,
        recurrent: bool = False,
        max_grad_norm: Optional[float] = None,
    ):
        self.model = q_function

        if gpu is not None and gpu >= 0:
            assert torch.cuda.is_available()
            self.device = torch.device("cuda:{}".format(gpu))
            self.model.to(self.device)
        else:
            self.device = torch.device("cpu")

        self.replay_buffer = replay_buffer
        self.optimizer = optimizer
        self.gamma = gamma
        self.explorer = explorer
        self.gpu = gpu
        self.target_update_interval = target_update_interval
        self.clip_delta = clip_delta
        self.phi = phi
        self.target_update_method = target_update_method
        self.soft_update_tau = soft_update_tau
        self.batch_accumulator = batch_accumulator
        assert batch_accumulator in ("mean", "sum")
        self.logger = logger
        self.batch_states = batch_states
        # Recurrent models are not supported by this agent yet, so the
        # `recurrent` argument is ignored and non-recurrent updates are forced.
        # self.recurrent = recurrent
        self.recurrent = False
        self.n_actions = self.model.n_actions
        self.value_buffer = self.model.v_buffer
        self.interval_tcp = interval_tcp
        self.n_trj_step = n_trj_step
        self.use_eva = use_eva
        update_func: Callable[..., None]
        if self.recurrent:
            assert isinstance(self.replay_buffer, AbstractEpisodicReplayBuffer)
            update_func = self.update_from_episodes
        else:
            update_func = self.update
        self.replay_updater = ReplayUpdater(
            replay_buffer=replay_buffer,
            update_func=update_func,
            batchsize=minibatch_size,
            episodic_update=recurrent,
            episodic_update_len=episodic_update_len,
            n_times_update=n_times_update,
            replay_start_size=replay_start_size,
            update_interval=update_interval,
        )
        self.minibatch_size = minibatch_size
        self.episodic_update_len = episodic_update_len
        self.replay_start_size = replay_start_size
        self.update_interval = update_interval
        self.max_grad_norm = max_grad_norm

        assert (
            target_update_interval % update_interval == 0
        ), "target_update_interval should be a multiple of update_interval"

        self.t = 0
        self.eval_t = 0
        self.optim_t = 0  # Compensate pytorch optim not having `t`
        self._cumulative_steps = 0
        self.target_model = make_target_model_as_copy(self.model.q_function)

        # Statistics
        self.q_record: collections.deque = collections.deque(maxlen=1000)
        self.loss_record: collections.deque = collections.deque(maxlen=100)

        # Recurrent states of the model
        self.train_recurrent_states: Any = None
        self.train_prev_recurrent_states: Any = None
        self.test_recurrent_states: Any = None

        # Error checking
        if (self.replay_buffer.capacity is not None
                and self.replay_buffer.capacity <
                self.replay_updater.replay_start_size):
            raise ValueError(
                "Replay start size cannot exceed replay buffer capacity.")

    @property
    def cumulative_steps(self) -> int:
        # cumulative_steps counts the overall steps during the training.
        return self._cumulative_steps

    def _setup_actor_learner_training(
        self,
        n_actors: int,
        actor_update_interval: int,
        update_counter: Any,
    ) -> Tuple[
        torch.nn.Module,
        Sequence[mp.connection.Connection],
        Sequence[mp.connection.Connection],
    ]:
        assert actor_update_interval > 0

        self.actor_update_interval = actor_update_interval
        self.update_counter = update_counter

        # Make a copy on shared memory and share among actors and the poller
        shared_model = copy.deepcopy(self.model).cpu()
        shared_model.share_memory()

        # Pipes are used for infrequent communication
        learner_pipes, actor_pipes = list(
            zip(*[mp.Pipe() for _ in range(n_actors)]))

        return (shared_model, learner_pipes, actor_pipes)

    def sync_target_network(self) -> None:
        """Synchronize target network with current network."""
        synchronize_parameters(
            src=self.model.q_function,
            dst=self.target_model,
            method=self.target_update_method,
            tau=self.soft_update_tau,
        )

    def update(self,
               experiences: List[List[Dict[str, Any]]],
               errors_out: Optional[list] = None) -> None:
        """Update the model from experiences

        Args:
            experiences (list): List of lists of dicts.
                For DQN, each dict must contain:
                  - state (object): State
                  - action (object): Action
                  - reward (float): Reward
                  - is_state_terminal (bool): True iff next state is terminal
                  - next_state (object): Next state
                  - weight (float, optional): Weight coefficient. It can be
                    used for importance sampling.
            errors_out (list or None): If set to a list, then TD-errors
                computed from the given experiences are appended to the list.

        Returns:
            None
        """
        has_weight = "weight" in experiences[0][0]
        exp_batch = batch_experiences(
            experiences,
            device=self.device,
            phi=self.phi,
            gamma=self.gamma,
            batch_states=self.batch_states,
        )
        if has_weight:
            exp_batch["weights"] = torch.tensor(
                [elem[0]["weight"] for elem in experiences],
                device=self.device,
                dtype=torch.float32,
            )
            if errors_out is None:
                errors_out = []
        loss = self._compute_loss(exp_batch, errors_out=errors_out)
        if has_weight:
            assert isinstance(self.replay_buffer, PrioritizedReplayBuffer)
            self.replay_buffer.update_errors(errors_out)

        self.loss_record.append(float(loss.detach().cpu().numpy()))

        self.optimizer.zero_grad()
        loss.backward()
        if self.max_grad_norm is not None:
            pfrl.utils.clip_l2_grad_norm_(self.model.parameters(),
                                          self.max_grad_norm)
        self.optimizer.step()
        self.optim_t += 1

    def update_from_episodes(self,
                             episodes: List[List[Dict[str, Any]]],
                             errors_out: Optional[list] = None) -> None:
        assert errors_out is None, "Recurrent DQN does not support PrioritizedBuffer"
        episodes = sorted(episodes, key=len, reverse=True)
        exp_batch = batch_recurrent_experiences(
            episodes,
            device=self.device,
            phi=self.phi,
            gamma=self.gamma,
            batch_states=self.batch_states,
        )
        loss = self._compute_loss(exp_batch, errors_out=None)
        self.loss_record.append(float(loss.detach().cpu().numpy()))
        self.optimizer.zero_grad()
        loss.backward()
        if self.max_grad_norm is not None:
            pfrl.utils.clip_l2_grad_norm_(self.model.parameters(),
                                          self.max_grad_norm)
        self.optimizer.step()
        self.optim_t += 1

    def _compute_target_values(
        self, exp_batch: Dict[str, Any]
    ) -> torch.Tensor:
        batch_next_state = exp_batch["next_state"]

        if self.recurrent:
            target_next_qout, _ = pack_and_forward(
                self.target_model,
                batch_next_state,
                exp_batch["next_recurrent_state"],
            )
        else:
            target_next_qout, _ = self.target_model(batch_next_state)
        next_q_max = target_next_qout.max

        batch_rewards = exp_batch["reward"]
        batch_terminal = exp_batch["is_state_terminal"]
        discount = exp_batch["discount"]

        return batch_rewards + discount * (1.0 - batch_terminal) * next_q_max

    def _compute_y_and_t(
        self, exp_batch: Dict[str, Any]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size = exp_batch["reward"].shape[0]

        # Compute Q-values for current states
        batch_state = exp_batch["state"]

        if self.recurrent:
            qout, _ = pack_and_forward(self.model, batch_state,
                                       exp_batch["recurrent_state"])
        else:
            qout, _ = self.model(batch_state)

        batch_actions = exp_batch["action"]
        batch_q = torch.reshape(qout.evaluate_actions(batch_actions),
                                (batch_size, 1))

        with torch.no_grad():
            batch_q_target = torch.reshape(
                self._compute_target_values(exp_batch), (batch_size, 1))

        return batch_q, batch_q_target

    def _compute_loss(self,
                      exp_batch: Dict[str, Any],
                      errors_out: Optional[list] = None) -> torch.Tensor:
        """Compute the Q-learning loss for a batch of experiences


        Args:
          exp_batch (dict): A dict of batched arrays of transitions
        Returns:
          Computed loss from the minibatch of experiences
        """
        y, t = self._compute_y_and_t(exp_batch)

        self.q_record.extend(y.detach().cpu().numpy().ravel())

        if errors_out is not None:
            del errors_out[:]
            delta = torch.abs(y - t)
            if delta.ndim == 2:
                delta = torch.sum(delta, dim=1)
            delta = delta.detach().cpu().numpy()
            for e in delta:
                errors_out.append(e)

        if "weights" in exp_batch:
            return compute_weighted_value_loss(
                y,
                t,
                exp_batch["weights"],
                clip_delta=self.clip_delta,
                batch_accumulator=self.batch_accumulator,
            )
        else:
            return compute_value_loss(
                y,
                t,
                clip_delta=self.clip_delta,
                batch_accumulator=self.batch_accumulator,
            )

    def _evaluate_model_and_update_recurrent_states(
        self, batch_obs: Sequence[Any]
    ):
        batch_xs = self.batch_states(batch_obs, self.device, self.phi)
        batch_h = None
        if self.recurrent:
            if self.training:
                self.train_prev_recurrent_states = self.train_recurrent_states
                batch_av, self.train_recurrent_states = one_step_forward(
                    self.model, batch_xs, self.train_recurrent_states)
            else:
                batch_av, self.test_recurrent_states = one_step_forward(
                    self.model, batch_xs, self.test_recurrent_states)
        else:
            batch_av, batch_h = self.model(batch_xs, self.use_eva)
        return batch_av, batch_h

    def batch_act(self, batch_obs: Sequence[Any]) -> Sequence[Any]:
        with torch.no_grad(), evaluating(self.model):
            batch_av, self.batch_h = self._evaluate_model_and_update_recurrent_states(
                batch_obs)
            batch_argmax = batch_av.greedy_actions.detach().cpu().numpy()
        if self.training:
            batch_action = [
                self.explorer.select_action(
                    self.t,
                    lambda: batch_argmax[i],
                    action_value=batch_av[i:i + 1],
                ) for i in range(len(batch_obs))
            ]
            self.batch_last_obs = list(batch_obs)
            self.batch_last_action = list(batch_action)
        else:
            batch_action = batch_argmax
        return batch_action

    def _batch_observe_train(
        self,
        batch_obs: Sequence[Any],
        batch_reward: Sequence[float],
        batch_done: Sequence[bool],
        batch_reset: Sequence[bool],
    ) -> None:

        for i in range(len(batch_obs)):
            self.t += 1
            self._cumulative_steps += 1
            # Update the target network
            if self.t % self.target_update_interval == 0:
                self.sync_target_network()
            if self.batch_last_obs[i] is not None:
                assert self.batch_last_action[i] is not None
                # Add a transition to the replay buffer
                transition = {
                    "state": self.batch_last_obs[i],
                    "action": self.batch_last_action[i],
                    "reward": batch_reward[i],
                    "feature": self.batch_h[i],
                    "next_state": batch_obs[i],
                    "next_action": None,
                    "is_state_terminal": batch_done[i],
                }
                if self.recurrent:
                    transition["recurrent_state"] = recurrent_state_as_numpy(
                        get_recurrent_state_at(
                            self.train_prev_recurrent_states, i, detach=True))
                    transition[
                        "next_recurrent_state"] = recurrent_state_as_numpy(
                            get_recurrent_state_at(self.train_recurrent_states,
                                                   i,
                                                   detach=True))
                self.replay_buffer.append(env_id=i, **transition)

                self._backup_if_necessary(self.t, self.batch_h[i])

                if batch_reset[i] or batch_done[i]:
                    self.batch_last_obs[i] = None
                    self.batch_last_action[i] = None
                    self.replay_buffer.stop_current_episode(env_id=i)
            self.replay_updater.update_if_necessary(self.t)

        if self.recurrent:
            # Reset recurrent states when episodes end
            self.train_prev_recurrent_states = None
            self.train_recurrent_states = _batch_reset_recurrent_states_when_episodes_end(  # NOQA
                batch_done=batch_done,
                batch_reset=batch_reset,
                recurrent_states=self.train_recurrent_states,
            )

    def _batch_observe_eval(
        self,
        batch_obs: Sequence[Any],
        batch_reward: Sequence[float],
        batch_done: Sequence[bool],
        batch_reset: Sequence[bool],
    ) -> None:

        for i in range(len(batch_obs)):
            self._backup_if_necessary(self.eval_t, self.batch_h[i])
            if batch_reset[i] or batch_done[i]:
                self.eval_t = 0
            else:
                self.eval_t += 1

        if self.recurrent:
            # Reset recurrent states when episodes end
            self.test_recurrent_states = _batch_reset_recurrent_states_when_episodes_end(  # NOQA
                batch_done=batch_done,
                batch_reset=batch_reset,
                recurrent_states=self.test_recurrent_states,
            )

    def batch_observe(
        self,
        batch_obs: Sequence[Any],
        batch_reward: Sequence[float],
        batch_done: Sequence[bool],
        batch_reset: Sequence[bool],
    ) -> None:
        if self.training:
            return self._batch_observe_train(batch_obs, batch_reward,
                                             batch_done, batch_reset)
        else:
            return self._batch_observe_eval(batch_obs, batch_reward,
                                            batch_done, batch_reset)

    def _backup_if_necessary(self, t, feature):
        if (t % self.interval_tcp == 0
                and len(self.replay_buffer) >= self.replay_buffer.capacity
                and self.use_eva):
            trajectory_list = self.replay_buffer.lookup(
                feature, self.n_trj_step)
            batch_trj = [
                batch_trajectory(trajectory,
                                 self.device,
                                 self.phi,
                                 batch_states=batch_states)
                for trajectory in trajectory_list
            ]
            q_np_arr = self._trajectory_centric_planning(batch_trj)
            batch_feature = [
                elem for trj in batch_trj for elem in trj['feature']
            ]
            batch_feature = torch.tensor(np.asarray(batch_feature),
                                         dtype=torch.float32)
            self.value_buffer.store(batch_feature, q_np_arr)

    def _trajectory_centric_planning(self, trajectories):
        state_shape = tuple(
            trajectories[0]["state"].shape)[1:]  # torch.Size -> tuple
        # Align shapes so that all trajectories can be processed in one batch
        # on the GPU. For Atari observations this is (0, 4, 84, 84).
        batch_states = torch.empty((0, ) + state_shape, dtype=torch.float32)
        for trajectory in trajectories:
            bs = torch.empty((self.n_trj_step, ) + state_shape,
                             dtype=torch.float32)
            bs[:len(trajectory["state"])] = trajectory["state"]
            # Equivalent to numpy.vstack
            batch_states = torch.cat((batch_states, bs), dim=0)

        batch_states = batch_states.to(self.device)
        with torch.no_grad(), evaluating(self.model):
            batch_q, _ = self.model(batch_states)
            q_theta_arr = batch_q.q_values.cpu()
            q_theta_arr = q_theta_arr.reshape(
                (len(trajectories), self.n_trj_step, self.n_actions))

        q_np_arr = torch.empty((0, self.n_actions), dtype=torch.float32)
        for q_np, trajectory in zip(q_theta_arr, trajectories):
            # batch_state = trajectory['state']
            batch_action = trajectory['action']
            batch_reward = trajectory['reward']

            q_np = q_np[:len(batch_action)]
            for t in range(len(batch_action) - 2, -1, -1):  # t := T-2, ..., 0
                # V_NP(s_{t+1}) := max_a Q_NP(s_{t+1}, a)
                V_np = torch.max(q_np[t + 1])
                q_np[t, batch_action[t]] = batch_reward[t] + self.gamma * V_np

            q_np_arr = torch.cat((q_np_arr, q_np.reshape(-1, self.n_actions)),
                                 dim=0)

        return q_np_arr.to(self.device)

    def _can_start_replay(self) -> bool:
        if len(self.replay_buffer) < self.replay_start_size:
            return False
        if self.recurrent:
            assert isinstance(self.replay_buffer, AbstractEpisodicReplayBuffer)
            if self.replay_buffer.n_episodes < self.minibatch_size:
                return False
        return True

    def _poll_pipe(
        self,
        actor_idx: int,
        pipe: mp.connection.Connection,
        replay_buffer_lock: mp.synchronize.Lock,
        exception_event: mp.synchronize.Event,
    ) -> None:
        if pipe.closed:
            return
        try:
            while pipe.poll() and not exception_event.is_set():
                cmd, data = pipe.recv()
                if cmd == "get_statistics":
                    assert data is None
                    with replay_buffer_lock:
                        stats = self.get_statistics()
                    pipe.send(stats)
                elif cmd == "load":
                    self.load(data)
                    pipe.send(None)
                elif cmd == "save":
                    self.save(data)
                    pipe.send(None)
                elif cmd == "transition":
                    with replay_buffer_lock:
                        if "env_id" not in data:
                            data["env_id"] = actor_idx
                        self.replay_buffer.append(**data)
                        self._cumulative_steps += 1
                elif cmd == "stop_episode":
                    idx = actor_idx if data is None else data
                    with replay_buffer_lock:
                        self.replay_buffer.stop_current_episode(env_id=idx)
                        stats = self.get_statistics()
                    pipe.send(stats)

                else:
                    raise RuntimeError(
                        "Unknown command from actor: {}".format(cmd))
        except EOFError:
            pipe.close()
        except Exception:
            self.logger.exception("Poller loop failed. Exiting")
            exception_event.set()

    def _learner_loop(
        self,
        shared_model: torch.nn.Module,
        pipes: Sequence[mp.connection.Connection],
        replay_buffer_lock: mp.synchronize.Lock,
        stop_event: mp.synchronize.Event,
        exception_event: mp.synchronize.Event,
        n_updates: Optional[int] = None,
    ) -> None:
        try:
            update_counter = 0
            # To stop this loop, call stop_event.set()
            while not stop_event.is_set():
                # Update model if possible
                if not self._can_start_replay():
                    continue
                if n_updates is not None:
                    assert self.optim_t <= n_updates
                    if self.optim_t == n_updates:
                        stop_event.set()
                        break

                if self.recurrent:
                    assert isinstance(self.replay_buffer,
                                      AbstractEpisodicReplayBuffer)
                    with replay_buffer_lock:
                        episodes = self.replay_buffer.sample_episodes(
                            self.minibatch_size, self.episodic_update_len)
                    self.update_from_episodes(episodes)
                else:
                    with replay_buffer_lock:
                        transitions = self.replay_buffer.sample(
                            self.minibatch_size)
                    self.update(transitions)

                # Update the shared model. This can be expensive if GPU is used
                # since this is a DtoH copy, so it is updated only at regular
                # intervals.
                update_counter += 1
                if update_counter % self.actor_update_interval == 0:
                    with self.update_counter.get_lock():
                        self.update_counter.value += 1
                        shared_model.load_state_dict(self.model.state_dict())

                # To keep the ratio of target updates to model updates,
                # here we calculate back the effective current timestep
                # from update_interval and number of updates so far.
                effective_timestep = self.optim_t * self.update_interval
                # We can safely assign self.t since in the learner
                # it isn't updated by any other method
                self.t = effective_timestep
                if effective_timestep % self.target_update_interval == 0:
                    self.sync_target_network()
        except Exception:
            self.logger.exception("Learner loop failed. Exiting")
            exception_event.set()

    def _poller_loop(
        self,
        shared_model: torch.nn.Module,
        pipes: Sequence[mp.connection.Connection],
        replay_buffer_lock: mp.synchronize.Lock,
        stop_event: mp.synchronize.Event,
        exception_event: mp.synchronize.Event,
    ) -> None:
        # To stop this loop, call stop_event.set()
        while not stop_event.is_set() and not exception_event.is_set():
            time.sleep(1e-6)
            # Poll actors for messages
            for i, pipe in enumerate(pipes):
                self._poll_pipe(i, pipe, replay_buffer_lock, exception_event)

    def setup_actor_learner_training(
        self,
        n_actors: int,
        update_counter: Optional[Any] = None,
        n_updates: Optional[int] = None,
        actor_update_interval: int = 8,
    ):
        if update_counter is None:
            update_counter = mp.Value(ctypes.c_ulong)

        (shared_model, learner_pipes,
         actor_pipes) = self._setup_actor_learner_training(
             n_actors, actor_update_interval, update_counter)
        exception_event = mp.Event()

        def make_actor(i):
            return pfrl.agents.StateQFunctionActor(
                pipe=actor_pipes[i],
                model=shared_model,
                explorer=self.explorer,
                phi=self.phi,
                batch_states=self.batch_states,
                logger=self.logger,
                recurrent=self.recurrent,
            )

        replay_buffer_lock = mp.Lock()

        poller_stop_event = mp.Event()
        poller = pfrl.utils.StoppableThread(
            target=self._poller_loop,
            kwargs=dict(
                shared_model=shared_model,
                pipes=learner_pipes,
                replay_buffer_lock=replay_buffer_lock,
                stop_event=poller_stop_event,
                exception_event=exception_event,
            ),
            stop_event=poller_stop_event,
        )

        learner_stop_event = mp.Event()
        learner = pfrl.utils.StoppableThread(
            target=self._learner_loop,
            kwargs=dict(
                shared_model=shared_model,
                pipes=learner_pipes,
                replay_buffer_lock=replay_buffer_lock,
                stop_event=learner_stop_event,
                n_updates=n_updates,
                exception_event=exception_event,
            ),
            stop_event=learner_stop_event,
        )

        return make_actor, learner, poller, exception_event

    def stop_episode(self) -> None:
        if self.recurrent:
            self.test_recurrent_states = None

    def get_statistics(self):
        return [
            ("average_q", _mean_or_nan(self.q_record)),
            ("average_loss", _mean_or_nan(self.loss_record)),
            ("cumulative_steps", self.cumulative_steps),
            ("n_updates", self.optim_t),
            ("rlen", len(self.replay_buffer)),
        ]
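
The heart of _trajectory_centric_planning above is a backward sweep that overwrites the taken action's value with an n-step estimate, Q_NP(s_t, a_t) = r_t + gamma * max_a Q_NP(s_{t+1}, a). A compact standalone restatement of that recursion on a single trajectory; trajectory_backup is an illustrative helper, not part of the class.

import torch


def trajectory_backup(q_theta, actions, rewards, gamma):
    # q_theta: (T, n_actions) parametric Q-values along one trajectory.
    # Returns the non-parametric estimate Q_NP obtained by sweeping backwards:
    #   Q_NP[t, a_t] = r_t + gamma * max_a Q_NP[t + 1, a]
    q_np = q_theta.clone()
    for t in range(len(actions) - 2, -1, -1):
        v_next = torch.max(q_np[t + 1])
        q_np[t, actions[t]] = rewards[t] + gamma * v_next
    return q_np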
Exemplo n.º 9
0
    def __init__(
        self,
        policy,
        q_func,
        actor_optimizer,
        critic_optimizer,
        replay_buffer,
        gamma,
        explorer,
        gpu=None,
        replay_start_size=50000,
        minibatch_size=32,
        update_interval=1,
        target_update_interval=10000,
        phi=lambda x: x,
        target_update_method="hard",
        soft_update_tau=1e-2,
        n_times_update=1,
        recurrent=False,
        episodic_update_len=None,
        logger=getLogger(__name__),
        batch_states=batch_states,
        burnin_action_func=None,
    ):

        self.model = nn.ModuleList([policy, q_func])
        if gpu is not None and gpu >= 0:
            assert torch.cuda.is_available()
            self.device = torch.device("cuda:{}".format(gpu))
            self.model.to(self.device)
        else:
            self.device = torch.device("cpu")

        self.replay_buffer = replay_buffer
        self.gamma = gamma
        self.explorer = explorer
        self.gpu = gpu
        self.target_update_interval = target_update_interval
        self.phi = phi
        self.target_update_method = target_update_method
        self.soft_update_tau = soft_update_tau
        self.logger = logger
        self.actor_optimizer = actor_optimizer
        self.critic_optimizer = critic_optimizer
        self.recurrent = recurrent
        assert not self.recurrent, "recurrent=True is not yet implemented"
        if self.recurrent:
            update_func = self.update_from_episodes
        else:
            update_func = self.update
        self.replay_updater = ReplayUpdater(
            replay_buffer=replay_buffer,
            update_func=update_func,
            batchsize=minibatch_size,
            episodic_update=recurrent,
            episodic_update_len=episodic_update_len,
            n_times_update=n_times_update,
            replay_start_size=replay_start_size,
            update_interval=update_interval,
        )
        self.batch_states = batch_states
        self.burnin_action_func = burnin_action_func

        self.t = 0
        self.last_state = None
        self.last_action = None
        self.target_model = copy.deepcopy(self.model)
        self.target_model.eval()
        self.q_record = collections.deque(maxlen=1000)
        self.actor_loss_record = collections.deque(maxlen=100)
        self.critic_loss_record = collections.deque(maxlen=100)
        self.n_updates = 0

        # Aliases for convenience
        self.policy, self.q_function = self.model
        self.target_policy, self.target_q_function = self.target_model

        self.sync_target_network()
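
The constructor above ends with sync_target_network(), which (as the next example shows) delegates to synchronize_parameters with method "hard" or "soft". A minimal sketch of what the two update rules amount to; sync_params is an illustrative restatement over parameters only, not PFRL's synchronize_parameters.

import torch


def sync_params(src: torch.nn.Module, dst: torch.nn.Module,
                method: str = "hard", tau: float = 1e-2) -> None:
    # hard: dst <- src
    # soft: dst <- (1 - tau) * dst + tau * src  (Polyak averaging)
    with torch.no_grad():
        for p_src, p_dst in zip(src.parameters(), dst.parameters()):
            if method == "hard":
                p_dst.copy_(p_src)
            else:
                p_dst.mul_(1.0 - tau).add_(tau * p_src)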
Exemplo n.º 10
0
class DDPG(AttributeSavingMixin, BatchAgent):
    """Deep Deterministic Policy Gradients.

    This can be used as SVG(0) by specifying a Gaussian policy instead of a
    deterministic policy.

    Args:
        policy (torch.nn.Module): Policy
        q_func (torch.nn.Module): Q-function
        actor_optimizer (Optimizer): Optimizer setup with the policy
        critic_optimizer (Optimizer): Optimizer setup with the Q-function
        replay_buffer (ReplayBuffer): Replay buffer
        gamma (float): Discount factor
        explorer (Explorer): Explorer that specifies an exploration strategy.
        gpu (int): GPU device id if not None nor negative.
        replay_start_size (int): if the replay buffer's size is less than
            replay_start_size, skip update
        minibatch_size (int): Minibatch size
        update_interval (int): Model update interval in step
        target_update_interval (int): Target model update interval in step
        phi (callable): Feature extractor applied to observations
        target_update_method (str): 'hard' or 'soft'.
        soft_update_tau (float): Tau of soft target update.
        n_times_update (int): Number of repetition of update
        batch_accumulator (str): 'mean' or 'sum'
        episodic_update (bool): Use full episodes for update if set True
        episodic_update_len (int or None): Subsequences of this length are used
            for update if set int and episodic_update=True
        logger (Logger): Logger used
        batch_states (callable): method which makes a batch of observations.
            default is `pfrl.utils.batch_states.batch_states`
        burnin_action_func (callable or None): If not None, this callable
            object is used to select actions before the model is updated
            one or more times during training.
    """

    saved_attributes = ("model", "target_model", "actor_optimizer",
                        "critic_optimizer")

    def __init__(
        self,
        policy,
        q_func,
        actor_optimizer,
        critic_optimizer,
        replay_buffer,
        gamma,
        explorer,
        gpu=None,
        replay_start_size=50000,
        minibatch_size=32,
        update_interval=1,
        target_update_interval=10000,
        phi=lambda x: x,
        target_update_method="hard",
        soft_update_tau=1e-2,
        n_times_update=1,
        recurrent=False,
        episodic_update_len=None,
        logger=getLogger(__name__),
        batch_states=batch_states,
        burnin_action_func=None,
    ):

        self.model = nn.ModuleList([policy, q_func])
        if gpu is not None and gpu >= 0:
            assert torch.cuda.is_available()
            self.device = torch.device("cuda:{}".format(gpu))
            self.model.to(self.device)
        else:
            self.device = torch.device("cpu")

        self.replay_buffer = replay_buffer
        self.gamma = gamma
        self.explorer = explorer
        self.gpu = gpu
        self.target_update_interval = target_update_interval
        self.phi = phi
        self.target_update_method = target_update_method
        self.soft_update_tau = soft_update_tau
        self.logger = logger
        self.actor_optimizer = actor_optimizer
        self.critic_optimizer = critic_optimizer
        self.recurrent = recurrent
        assert not self.recurrent, "recurrent=True is not yet implemented"
        if self.recurrent:
            update_func = self.update_from_episodes
        else:
            update_func = self.update
        self.replay_updater = ReplayUpdater(
            replay_buffer=replay_buffer,
            update_func=update_func,
            batchsize=minibatch_size,
            episodic_update=recurrent,
            episodic_update_len=episodic_update_len,
            n_times_update=n_times_update,
            replay_start_size=replay_start_size,
            update_interval=update_interval,
        )
        self.batch_states = batch_states
        self.burnin_action_func = burnin_action_func

        self.t = 0
        self.last_state = None
        self.last_action = None
        self.target_model = copy.deepcopy(self.model)
        self.target_model.eval()
        self.q_record = collections.deque(maxlen=1000)
        self.actor_loss_record = collections.deque(maxlen=100)
        self.critic_loss_record = collections.deque(maxlen=100)
        self.n_updates = 0

        # Aliases for convenience
        self.policy, self.q_function = self.model
        self.target_policy, self.target_q_function = self.target_model

        self.sync_target_network()

    def sync_target_network(self):
        """Synchronize target network with current network."""
        synchronize_parameters(
            src=self.model,
            dst=self.target_model,
            method=self.target_update_method,
            tau=self.soft_update_tau,
        )

    # Update Q-function
    def compute_critic_loss(self, batch):
        """Compute loss for critic."""

        batch_next_state = batch["next_state"]
        batch_rewards = batch["reward"]
        batch_terminal = batch["is_state_terminal"]
        batch_state = batch["state"]
        batch_actions = batch["action"]
        batchsize = len(batch_rewards)

        with torch.no_grad():
            assert not self.recurrent
            next_actions = self.target_policy(batch_next_state).sample()
            next_q = self.target_q_function((batch_next_state, next_actions))
            target_q = batch_rewards + self.gamma * (
                1.0 - batch_terminal) * next_q.reshape((batchsize, ))

        predict_q = self.q_function((batch_state, batch_actions)).reshape(
            (batchsize, ))

        loss = F.mse_loss(target_q, predict_q)

        # Update stats
        self.critic_loss_record.append(float(loss.detach().cpu().numpy()))

        return loss

    def compute_actor_loss(self, batch):
        """Compute loss for actor."""

        batch_state = batch["state"]
        onpolicy_actions = self.policy(batch_state).rsample()
        q = self.q_function((batch_state, onpolicy_actions))
        loss = -q.mean()

        # Update stats
        self.q_record.extend(q.detach().cpu().numpy())
        self.actor_loss_record.append(float(loss.detach().cpu().numpy()))

        return loss

    def update(self, experiences, errors_out=None):
        """Update the model from experiences"""

        batch = batch_experiences(experiences, self.device, self.phi,
                                  self.gamma)

        self.critic_optimizer.zero_grad()
        self.compute_critic_loss(batch).backward()
        self.critic_optimizer.step()

        self.actor_optimizer.zero_grad()
        self.compute_actor_loss(batch).backward()
        self.actor_optimizer.step()

        self.n_updates += 1

    def update_from_episodes(self, episodes, errors_out=None):
        raise NotImplementedError

        # NOTE: The code below is unreachable and appears to be legacy
        # (Chainer-style) recurrent-update code kept for reference.
        # Sort episodes in descending order of length
        sorted_episodes = list(reversed(sorted(episodes, key=len)))
        max_epi_len = len(sorted_episodes[0])

        # Precompute all the input batches
        batches = []
        for i in range(max_epi_len):
            transitions = []
            for ep in sorted_episodes:
                if len(ep) <= i:
                    break
                transitions.append([ep[i]])
            batch = batch_experiences(transitions,
                                      xp=self.device,
                                      phi=self.phi,
                                      gamma=self.gamma)
            batches.append(batch)

        with self.model.state_reset(), self.target_model.state_reset():

            # Since the target model is evaluated one-step ahead,
            # its internal states need to be updated
            self.target_q_function.update_state(batches[0]["state"],
                                                batches[0]["action"])
            self.target_policy(batches[0]["state"])

            # Update critic through time
            critic_loss = 0
            for batch in batches:
                critic_loss += self.compute_critic_loss(batch)
            self.critic_optimizer.update(lambda: critic_loss / max_epi_len)

        with self.model.state_reset():

            # Update actor through time
            actor_loss = 0
            for batch in batches:
                actor_loss += self.compute_actor_loss(batch)
            self.actor_optimizer.update(lambda: actor_loss / max_epi_len)

    def batch_act(self, batch_obs):
        if self.training:
            return self._batch_act_train(batch_obs)
        else:
            return self._batch_act_eval(batch_obs)

    def batch_observe(self, batch_obs, batch_reward, batch_done, batch_reset):
        if self.training:
            self._batch_observe_train(batch_obs, batch_reward, batch_done,
                                      batch_reset)

    def _batch_select_greedy_actions(self, batch_obs):
        with torch.no_grad(), evaluating(self.policy):
            batch_xs = self.batch_states(batch_obs, self.device, self.phi)
            batch_action = self.policy(batch_xs).sample()
            return batch_action.cpu().numpy()

    def _batch_act_eval(self, batch_obs):
        assert not self.training
        return self._batch_select_greedy_actions(batch_obs)

    def _batch_act_train(self, batch_obs):
        assert self.training
        if self.burnin_action_func is not None and self.n_updates == 0:
            batch_action = [
                self.burnin_action_func() for _ in range(len(batch_obs))
            ]
        else:
            batch_greedy_action = self._batch_select_greedy_actions(batch_obs)
            batch_action = [
                self.explorer.select_action(self.t,
                                            lambda: batch_greedy_action[i])
                for i in range(len(batch_greedy_action))
            ]

        self.batch_last_obs = list(batch_obs)
        self.batch_last_action = list(batch_action)

        return batch_action

    def _batch_observe_train(self, batch_obs, batch_reward, batch_done,
                             batch_reset):
        assert self.training
        for i in range(len(batch_obs)):
            self.t += 1
            # Update the target network
            if self.t % self.target_update_interval == 0:
                self.sync_target_network()
            if self.batch_last_obs[i] is not None:
                assert self.batch_last_action[i] is not None
                # Add a transition to the replay buffer
                self.replay_buffer.append(
                    state=self.batch_last_obs[i],
                    action=self.batch_last_action[i],
                    reward=batch_reward[i],
                    next_state=batch_obs[i],
                    next_action=None,
                    is_state_terminal=batch_done[i],
                    env_id=i,
                )
                if batch_reset[i] or batch_done[i]:
                    self.batch_last_obs[i] = None
                    self.batch_last_action[i] = None
                    self.replay_buffer.stop_current_episode(env_id=i)
            self.replay_updater.update_if_necessary(self.t)

    def get_statistics(self):
        return [
            ("average_q", _mean_or_nan(self.q_record)),
            ("average_actor_loss", _mean_or_nan(self.actor_loss_record)),
            ("average_critic_loss", _mean_or_nan(self.critic_loss_record)),
            ("n_updates", self.n_updates),
        ]
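
The batch_act / batch_observe pair defines this agent's interaction protocol with a vectorized environment. Below is a hedged sketch of a driver loop, assuming a Gym-style vectorized env that auto-resets finished sub-environments; run_batch_agent, vec_env, and n_steps are placeholders, not library names.

import numpy as np


def run_batch_agent(agent, vec_env, n_steps):
    # Act on a batch of observations, then feed back the resulting
    # rewards/done flags so the agent can store transitions and update.
    obs = vec_env.reset()
    for _ in range(n_steps):
        actions = agent.batch_act(obs)
        obs, rewards, dones, _infos = vec_env.step(actions)
        # `resets` marks episodes cut off by a time limit rather than a
        # terminal state; this sketch conservatively treats none as truncated.
        resets = np.zeros(len(obs), dtype=bool)
        agent.batch_observe(obs, rewards, dones, resets)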
Exemplo n.º 11
0
class SQIL(agent.AttributeSavingMixin, agent.BatchAgent):
    """Deep Q-Network algorithm.

    Args:
        q_function (StateQFunction): Q-function
        optimizer (Optimizer): Optimizer that is already setup
        replay_buffer (ReplayBuffer): Replay buffer
        gamma (float): Discount factor
        explorer (Explorer): Explorer that specifies an exploration strategy.
        gpu (int): GPU device id if not None nor negative.
        replay_start_size (int): if the replay buffer's size is less than
            replay_start_size, skip update
        minibatch_size (int): Minibatch size
        update_interval (int): Model update interval in step
        target_update_interval (int): Target model update interval in step
        clip_delta (bool): Clip delta if set True
        phi (callable): Feature extractor applied to observations
        target_update_method (str): 'hard' or 'soft'.
        soft_update_tau (float): Tau of soft target update.
        n_times_update (int): Number of repetition of update
        batch_accumulator (str): 'mean' or 'sum'
        episodic_update_len (int or None): Subsequences of this length are used
            for update if set int and episodic_update=True
        logger (Logger): Logger used
        batch_states (callable): method which makes a batch of observations.
            default is `pfrl.utils.batch_states.batch_states`
        recurrent (bool): If set to True, `model` is assumed to implement
            `pfrl.nn.Recurrent` and is updated in a recurrent
            manner.

    Changes from DQN:
        - Recurrent support is removed.
        - An expert dataset is added for imitation.
    """

    saved_attributes = ("model", "target_model", "optimizer")

    def __init__(
            self,
            q_function,
            optimizer,
            replay_buffer,
            gamma,
            explorer,
            gpu=None,
            replay_start_size=50000,
            minibatch_size=32,
            update_interval=1,
            target_update_interval=10000,
            clip_delta=True,
            phi=lambda x: x,
            target_update_method="hard",
            soft_update_tau=1e-2,
            n_times_update=1,
            batch_accumulator="mean",
            episodic_update_len=None,
            logger=getLogger(__name__),
            batch_states=batch_states,
            expert_dataset=None,
            reward_scale=1.0,
            experience_lambda=1.0,
            recurrent=False,
            reward_boundaries=None,  # specific to options
    ):
        self.expert_dataset = expert_dataset

        self.model = q_function

        if gpu is not None and gpu >= 0:
            assert torch.cuda.is_available()
            self.device = torch.device("cuda:{}".format(gpu))
            self.model.to(self.device)
        else:
            self.device = torch.device("cpu")

        self.replay_buffer = replay_buffer
        self.optimizer = optimizer
        self.gamma = gamma
        self.explorer = explorer
        self.gpu = gpu
        self.target_update_interval = target_update_interval
        self.clip_delta = clip_delta
        self.phi = phi
        self.target_update_method = target_update_method
        self.soft_update_tau = soft_update_tau
        self.batch_accumulator = batch_accumulator
        assert batch_accumulator in ("mean", "sum")
        self.logger = logger
        self.batch_states = batch_states
        self.recurrent = recurrent
        if self.recurrent:
            update_func = self.update_from_episodes
        else:
            update_func = self.update
        self.replay_updater = ReplayUpdater(
            replay_buffer=replay_buffer,
            update_func=update_func,
            batchsize=minibatch_size,
            episodic_update=recurrent,
            episodic_update_len=episodic_update_len,
            n_times_update=n_times_update,
            replay_start_size=replay_start_size,
            update_interval=update_interval,
        )
        self.minibatch_size = minibatch_size
        self.episodic_update_len = episodic_update_len
        self.replay_start_size = replay_start_size
        self.update_interval = update_interval

        assert (
            target_update_interval % update_interval == 0
        ), "target_update_interval should be a multiple of update_interval"

        # For imitation
        self.reward_scale = reward_scale
        self.experience_lambda = experience_lambda

        if reward_boundaries is not None and self.expert_dataset is not None:
            self.reward_based_sampler = RewardBasedSampler(self.expert_dataset,
                                                           reward_boundaries,
                                                           reward=reward_scale)
        else:
            self.reward_based_sampler = None

        self.t = 0
        self.optim_t = 0  # Compensate pytorch optim not having `t`
        self._cumulative_steps = 0
        self.last_state = None
        self.last_action = None
        self.target_model = None
        self.sync_target_network()

        # Statistics
        self.q_record = collections.deque(maxlen=1000)
        self.loss_record = collections.deque(maxlen=100)

        # Recurrent states of the model
        self.train_recurrent_states = None
        self.train_prev_recurrent_states = None
        self.test_recurrent_states = None

        # Error checking
        if (self.replay_buffer.capacity is not None
                and self.replay_buffer.capacity <
                self.replay_updater.replay_start_size):
            raise ValueError(
                "Replay start size cannot exceed replay buffer capacity.")

    @property
    def cumulative_steps(self):
        # cumulative_steps counts the overall steps during the training.
        return self._cumulative_steps

    def sync_target_network(self):
        """Synchronize target network with current network."""
        if self.target_model is None:
            self.target_model = copy.deepcopy(self.model)

            def flatten_parameters(mod):
                if isinstance(mod, torch.nn.RNNBase):
                    mod.flatten_parameters()

            # RNNBase.flatten_parameters must be called again after deep-copy.
            # See: https://discuss.pytorch.org/t/why-do-we-need-flatten-parameters-when-using-rnn-with-dataparallel/46506  # NOQA
            self.target_model.apply(flatten_parameters)
            # Set the target network to evaluation mode only.
            self.target_model.eval()
        else:
            synchronize_parameters(
                src=self.model,
                dst=self.target_model,
                method=self.target_update_method,
                tau=self.soft_update_tau,
            )

    def update(self, experiences, errors_out=None):
        """Update the model from experiences

        Args:
            experiences (list): List of lists of dicts.
                For DQN, each dict must contain:
                  - state (object): State
                  - action (object): Action
                  - reward (float): Reward
                  - is_state_terminal (bool): True iff next state is terminal
                  - next_state (object): Next state
                  - weight (float, optional): Weight coefficient. It can be
                    used for importance sampling.
            errors_out (list or None): If set to a list, then TD-errors
                computed from the given experiences are appended to the list.

        Returns:
            None

        Changes from DQN:
            Also learns from expert demonstrations.
        """
        has_weight = "weight" in experiences[0][0]
        exp_batch = batch_experiences(
            experiences,
            device=self.device,
            phi=self.phi,
            gamma=self.gamma,
            batch_states=self.batch_states,
        )
        if has_weight:
            exp_batch["weights"] = torch.tensor(
                [elem[0]["weight"] for elem in experiences],
                device=self.device,
                dtype=torch.float32,
            )
            if errors_out is None:
                errors_out = []

        if self.reward_based_sampler is not None:
            demo_experiences = self.reward_based_sampler.sample(experiences)
        else:
            demo_experiences = load_experiences_from_demonstrations(
                self.expert_dataset, self.replay_updater.batchsize,
                self.reward_scale)
        demo_batch = batch_experiences(
            demo_experiences,
            device=self.device,
            phi=self.phi,
            gamma=self.gamma,
            batch_states=self.batch_states,
        )

        loss = self._compute_loss(exp_batch, demo_batch, errors_out=errors_out)
        if has_weight:
            self.replay_buffer.update_errors(errors_out)

        self.loss_record.append(float(loss.detach().cpu().numpy()))

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.optim_t += 1

    def update_from_episodes(self, episodes, errors_out=None):
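        """Update the model from episodes sampled from an episodic buffer.

        Like ``update``, but the agent's own experience is batched with its
        recurrent states preserved, while demonstrations are still batched as
        flat transitions.
        """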
        assert errors_out is None, "Recurrent DQN does not support PrioritizedBuffer"
        episodes = sorted(episodes, key=len, reverse=True)
        exp_batch = batch_recurrent_experiences(
            episodes,
            device=self.device,
            phi=self.phi,
            gamma=self.gamma,
            batch_states=self.batch_states,
        )

        demo_experiences = load_experiences_from_demonstrations(
            self.expert_dataset, self.replay_updater.batchsize,
            self.reward_scale)
        demo_batch = batch_experiences(
            demo_experiences,
            device=self.device,
            phi=self.phi,
            gamma=self.gamma,
            batch_states=self.batch_states,
        )

        loss = self._compute_loss(exp_batch, demo_batch, errors_out=None)
        self.loss_record.append(float(loss.detach().cpu().numpy()))
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.optim_t += 1

    def _compute_target_values(self, exp_batch):
        """
        Changes from DQN:
            Consider soft Bellman error
        """
        batch_next_state = exp_batch["next_state"]

        target_next_qout = self.target_model(batch_next_state)

        next_q_max = torch.broadcast_tensors(
            target_next_qout.q_values.max(dim=-1, keepdim=True)[0],
            target_next_qout.q_values)[0]
        next_q_soft = (
            next_q_max[:, 0] +
            (target_next_qout.q_values - next_q_max).exp().sum(dim=-1).log())

        batch_rewards = exp_batch["reward"]
        batch_terminal = exp_batch["is_state_terminal"]
        discount = exp_batch["discount"]

        # Hard-max (standard DQN) backup, kept for reference:
        # return batch_rewards + discount * (1.0 - batch_terminal) * next_q_max
        return batch_rewards + discount * (1.0 - batch_terminal) * next_q_soft

    def _compute_y_and_t(self, exp_batch):
        batch_size = exp_batch["reward"].shape[0]

        # Compute Q-values for current states
        batch_state = exp_batch["state"]

        if self.recurrent:
            qout, _ = pack_and_forward(self.model, batch_state,
                                       exp_batch["recurrent_state"])
        else:
            qout = self.model(batch_state)

        batch_actions = exp_batch["action"]
        batch_q = torch.reshape(qout.evaluate_actions(batch_actions),
                                (batch_size, 1))

        with torch.no_grad():
            batch_q_target = torch.reshape(
                self._compute_target_values(exp_batch), (batch_size, 1))

        return batch_q, batch_q_target

    def __compute_loss(self, exp_batch, errors_out):
        y, t = self._compute_y_and_t(exp_batch)

        self.q_record.extend(y.detach().cpu().numpy().ravel())

        # Optionally report per-transition absolute TD errors, e.g. so a
        # prioritized replay buffer can update its priorities.
        if errors_out is not None:
            del errors_out[:]
            delta = torch.abs(y - t)
            if delta.ndim == 2:
                delta = torch.sum(delta, dim=1)
            delta = delta.detach().cpu().numpy()
            for e in delta:
                errors_out.append(e)

        if "weights" in exp_batch:
            return compute_weighted_value_loss(
                y,
                t,
                exp_batch["weights"],
                clip_delta=self.clip_delta,
                batch_accumulator=self.batch_accumulator,
            )
        else:
            return compute_value_loss(
                y,
                t,
                clip_delta=self.clip_delta,
                batch_accumulator=self.batch_accumulator,
            )

    def _compute_loss(self, exp_batch, demo_batch, errors_out=None):
        """Compute the Q-learning loss for a batch of experiences


        Args:
          exp_batch (dict): A dict of batched arrays of transitions
        Returns:
          Computed loss from the minibatch of experiences

        Changes from DQN:
            Learned from demonstrations
        """
        exp_loss = self.__compute_loss(exp_batch, errors_out=errors_out)
        demo_loss = self.__compute_loss(demo_batch, errors_out=None)
        return (exp_loss * self.experience_lambda + demo_loss) / 2

    def _evaluate_model_and_update_recurrent_states(self, batch_obs):
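        # Run one forward pass on a batch of observations; for recurrent
        # models this also advances the stored train/test recurrent state.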
        batch_xs = self.batch_states(batch_obs, self.device, self.phi)
        if self.recurrent:
            if self.training:
                self.train_prev_recurrent_states = self.train_recurrent_states
                batch_av, self.train_recurrent_states = one_step_forward(
                    self.model, batch_xs, self.train_recurrent_states)
            else:
                batch_av, self.test_recurrent_states = one_step_forward(
                    self.model, batch_xs, self.test_recurrent_states)
        else:
            batch_av = self.model(batch_xs)
        return batch_av

    def batch_act(self, batch_obs):
        with torch.no_grad(), evaluating(self.model):
            batch_av = self._evaluate_model_and_update_recurrent_states(
                batch_obs)
            batch_argmax = batch_av.greedy_actions.cpu().numpy()
        if self.training:
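            # During training, the explorer (e.g. epsilon-greedy) chooses
            # between the greedy action and an exploratory one.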
            batch_action = [
                self.explorer.select_action(
                    self.t,
                    lambda: batch_argmax[i],
                    action_value=batch_av[i:i + 1],
                ) for i in range(len(batch_obs))
            ]
            self.batch_last_obs = list(batch_obs)
            self.batch_last_action = list(batch_action)
        else:
            # Stochastic evaluation: actions still go through the explorer.
            batch_action = [
                self.explorer.select_action(
                    self.t,
                    lambda: batch_argmax[i],
                    action_value=batch_av[i:i + 1],
                ) for i in range(len(batch_obs))
            ]
            # Deterministic (greedy) evaluation alternative:
            # batch_action = batch_argmax
        return batch_action

    def _batch_observe_train(self, batch_obs, batch_reward, batch_done,
                             batch_reset):

        for i in range(len(batch_obs)):
            self.t += 1
            self._cumulative_steps += 1
            # Update the target network
            if self.t % self.target_update_interval == 0:
                self.sync_target_network()
            if self.batch_last_obs[i] is not None:
                assert self.batch_last_action[i] is not None
                # Add a transition to the replay buffer
                transition = {
                    "state": self.batch_last_obs[i],
                    "action": self.batch_last_action[i],
                    "reward": batch_reward[i],
                    "next_state": batch_obs[i],
                    "next_action": None,
                    "is_state_terminal": batch_done[i],
                }
                if self.recurrent:
                    transition["recurrent_state"] = recurrent_state_as_numpy(
                        get_recurrent_state_at(
                            self.train_prev_recurrent_states, i, detach=True))
                    transition[
                        "next_recurrent_state"] = recurrent_state_as_numpy(
                            get_recurrent_state_at(self.train_recurrent_states,
                                                   i,
                                                   detach=True))
                self.replay_buffer.append(env_id=i, **transition)
                if batch_reset[i] or batch_done[i]:
                    self.batch_last_obs[i] = None
                    self.batch_last_action[i] = None
                    self.replay_buffer.stop_current_episode(env_id=i)
            self.replay_updater.update_if_necessary(self.t)

        if self.recurrent:
            # Reset recurrent states when episodes end
            self.train_prev_recurrent_states = None
            self.train_recurrent_states = _batch_reset_recurrent_states_when_episodes_end(  # NOQA
                batch_done=batch_done,
                batch_reset=batch_reset,
                recurrent_states=self.train_recurrent_states,
            )

    def _batch_observe_eval(self, batch_obs, batch_reward, batch_done,
                            batch_reset):
        if self.recurrent:
            # Reset recurrent states when episodes end
            self.test_recurrent_states = _batch_reset_recurrent_states_when_episodes_end(  # NOQA
                batch_done=batch_done,
                batch_reset=batch_reset,
                recurrent_states=self.test_recurrent_states,
            )

    def batch_observe(self, batch_obs, batch_reward, batch_done, batch_reset):
        if self.training:
            return self._batch_observe_train(batch_obs, batch_reward,
                                             batch_done, batch_reset)
        else:
            return self._batch_observe_eval(batch_obs, batch_reward,
                                            batch_done, batch_reset)

    def _can_start_replay(self):
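        # Updates begin only once enough transitions are stored; recurrent
        # models additionally need at least `minibatch_size` complete episodes.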
        if len(self.replay_buffer) < self.replay_start_size:
            return False
        if self.recurrent and self.replay_buffer.n_episodes < self.minibatch_size:
            return False
        return True

    def stop_episode(self):
        if self.recurrent:
            self.test_recurrent_states = None

    def get_statistics(self):
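        # Returned as (name, value) pairs; judging by its name, _mean_or_nan
        # presumably reports NaN while the corresponding record is still empty.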
        return [
            ("average_q", _mean_or_nan(self.q_record)),
            ("average_loss", _mean_or_nan(self.loss_record)),
            ("cumulative_steps", self.cumulative_steps),
            ("n_updates", self.optim_t),
            ("rlen", len(self.replay_buffer)),
        ]