Example #1
class DDPG(AttributeSavingMixin, Agent):
    """Deep Deterministic Policy Gradients.

    This can be used as SVG(0) by specifying a Gaussian policy instead of a
    deterministic policy.

    Args:
        model (DDPGModel): DDPG model that contains both a policy and a
            Q-function
        actor_optimizer (Optimizer): Optimizer setup with the policy
        critic_optimizer (Optimizer): Optimizer setup with the Q-function
        replay_buffer (ReplayBuffer): Replay buffer
        gamma (float): Discount factor
        explorer (Explorer): Explorer that specifies an exploration strategy.
        gpu (int): GPU device id if not None nor negative.
        replay_start_size (int): if the replay buffer's size is less than
            replay_start_size, skip update
        minibatch_size (int): Minibatch size
        update_frequency (int): Model update frequency in step
        target_update_frequency (int): Target model update frequency in step
        phi (callable): Feature extractor applied to observations
        target_update_method (str): 'hard' or 'soft'.
        soft_update_tau (float): Tau of soft target update.
        n_times_update (int): Number of repetitions of update
        average_q_decay (float): Decay rate of average Q, only used for
            recording statistics
        average_loss_decay (float): Decay rate of average loss, only used for
            recording statistics
        episodic_update (bool): Use full episodes for update if set True
        episodic_update_len (int or None): Subsequences of this length are used
            for update if set int and episodic_update=True
        logger (Logger): Logger used
        batch_states (callable): method which makes a batch of observations.
            default is `chainerrl.misc.batch_states.batch_states`
    """

    saved_attributes = ('model', 'target_model', 'actor_optimizer',
                        'critic_optimizer')

    def __init__(self,
                 model,
                 actor_optimizer,
                 critic_optimizer,
                 replay_buffer,
                 gamma,
                 explorer,
                 gpu=None,
                 replay_start_size=50000,
                 minibatch_size=32,
                 update_frequency=1,
                 target_update_frequency=10000,
                 phi=lambda x: x,
                 target_update_method='hard',
                 soft_update_tau=1e-2,
                 n_times_update=1,
                 average_q_decay=0.999,
                 average_loss_decay=0.99,
                 episodic_update=False,
                 episodic_update_len=None,
                 logger=getLogger(__name__),
                 batch_states=batch_states):

        self.model = model

        if gpu is not None and gpu >= 0:
            cuda.get_device(gpu).use()
            self.model.to_gpu(device=gpu)

        self.xp = self.model.xp
        self.replay_buffer = replay_buffer
        self.gamma = gamma
        self.explorer = explorer
        self.gpu = gpu
        self.target_update_frequency = target_update_frequency
        self.phi = phi
        self.target_update_method = target_update_method
        self.soft_update_tau = soft_update_tau
        self.logger = logger
        self.average_q_decay = average_q_decay
        self.average_loss_decay = average_loss_decay
        self.actor_optimizer = actor_optimizer
        self.critic_optimizer = critic_optimizer
        if episodic_update:
            update_func = self.update_from_episodes
        else:
            update_func = self.update
        self.replay_updater = ReplayUpdater(
            replay_buffer=replay_buffer,
            update_func=update_func,
            batchsize=minibatch_size,
            episodic_update=episodic_update,
            episodic_update_len=episodic_update_len,
            n_times_update=n_times_update,
            replay_start_size=replay_start_size,
            update_frequency=update_frequency,
        )
        self.batch_states = batch_states

        self.t = 0
        self.last_state = None
        self.last_action = None
        self.target_model = copy.deepcopy(self.model)
        self.average_q = 0
        self.average_actor_loss = 0.0
        self.average_critic_loss = 0.0

        # Aliases for convenience
        self.q_function = self.model['q_function']
        self.policy = self.model['policy']
        self.target_q_function = self.target_model['q_function']
        self.target_policy = self.target_model['policy']

        self.sync_target_network()

    def sync_target_network(self):
        """Synchronize target network with current network."""
        synchronize_parameters(src=self.model,
                               dst=self.target_model,
                               method=self.target_update_method,
                               tau=self.soft_update_tau)

    # Update Q-function
    def compute_critic_loss(self, batch):
        """Compute loss for critic.

        Preconditions:
          target_q_function must have seen up to s_t and a_t.
          target_policy must have seen up to s_t.
          q_function must have seen up to s_{t-1}.
        Postconditions:
          target_q_function must have seen up to s_{t+1} and a_{t+1}.
          target_policy must have seen up to s_{t+1}.
          q_function must have seen up to s_t.
        """

        batch_next_state = batch['next_state']
        batch_rewards = batch['reward']
        batch_terminal = batch['is_state_terminal']
        batch_state = batch['state']
        batch_actions = batch['action']
        batch_next_actions = batch['next_action']
        batchsize = len(batch_rewards)

        with chainer.no_backprop_mode():
            # Target policy observes s_{t+1}
            next_actions = self.target_policy(batch_next_state,
                                              test=True).sample()

            # Q(s_{t+1}, mu(s_{t+1})) is evaluated.
            # This should not affect the internal state of Q.
            with state_kept(self.target_q_function):
                next_q = self.target_q_function(batch_next_state,
                                                next_actions,
                                                test=True)

            # Target Q-function observes s_{t+1} and a_{t+1}
            if isinstance(self.target_q_function, Recurrent):
                self.target_q_function.update_state(batch_next_state,
                                                    batch_next_actions,
                                                    test=True)

            target_q = batch_rewards + self.gamma * \
                (1.0 - batch_terminal) * F.reshape(next_q, (batchsize,))

        # Estimated Q-function observes s_t and a_t
        predict_q = F.reshape(
            self.q_function(batch_state, batch_actions, test=False),
            (batchsize, ))

        loss = F.mean_squared_error(target_q, predict_q)

        # Update stats
        self.average_critic_loss *= self.average_loss_decay
        self.average_critic_loss += ((1 - self.average_loss_decay) *
                                     float(loss.data))

        return loss

    def compute_actor_loss(self, batch):
        """Compute loss for actor.

        Preconditions:
          q_function must have seen up to s_{t-1} and a_{t-1}.
          policy must have seen up to s_{t-1}.
        Postconditions:
          q_function must have seen up to s_t and a_t.
          policy must have seen up to s_t.
        """

        batch_state = batch['state']
        batch_action = batch['action']
        batch_size = len(batch_action)

        # Estimated policy observes s_t
        onpolicy_actions = self.policy(batch_state, test=False).sample()

        # Q(s_t, mu(s_t)) is evaluated.
        # This should not affect the internal state of Q.
        with state_kept(self.q_function):
            q = self.q_function(batch_state, onpolicy_actions, test=False)

        # Estimated Q-function observes s_t and a_t
        if isinstance(self.q_function, Recurrent):
            self.q_function.update_state(batch_state, batch_action, test=False)

        # Since we want to maximize Q, loss is negation of Q
        loss = -F.sum(q) / batch_size

        # Update stats
        self.average_actor_loss *= self.average_loss_decay
        self.average_actor_loss += ((1 - self.average_loss_decay) *
                                    float(loss.data))
        return loss

    def update(self, experiences, errors_out=None):
        """Update the model from experiences"""

        batch = batch_experiences(experiences, self.xp, self.phi)
        self.critic_optimizer.update(lambda: self.compute_critic_loss(batch))
        self.actor_optimizer.update(lambda: self.compute_actor_loss(batch))

    def update_from_episodes(self, episodes, errors_out=None):
        # Sort episodes desc by their lengths
        sorted_episodes = list(reversed(sorted(episodes, key=len)))
        max_epi_len = len(sorted_episodes[0])

        # Precompute all the input batches
        batches = []
        for i in range(max_epi_len):
            transitions = []
            for ep in sorted_episodes:
                if len(ep) <= i:
                    break
                transitions.append(ep[i])
            batch = batch_experiences(transitions, xp=self.xp, phi=self.phi)
            batches.append(batch)

        with self.model.state_reset():
            with self.target_model.state_reset():

                # Since the target model is evaluated one-step ahead,
                # its internal states need to be updated
                self.target_q_function.update_state(batches[0]['state'],
                                                    batches[0]['action'],
                                                    test=True)
                self.target_policy(batches[0]['state'], test=True)

                # Update critic through time
                critic_loss = 0
                for batch in batches:
                    critic_loss += self.compute_critic_loss(batch)
                self.critic_optimizer.update(lambda: critic_loss / max_epi_len)

        with self.model.state_reset():

            # Update actor through time
            actor_loss = 0
            for batch in batches:
                actor_loss += self.compute_actor_loss(batch)
            self.actor_optimizer.update(lambda: actor_loss / max_epi_len)

    def act_and_train(self, state, reward):

        self.logger.debug('t:%s r:%s', self.t, reward)

        greedy_action = self.act(state)
        action = self.explorer.select_action(self.t, lambda: greedy_action)
        self.t += 1

        # Update the target network
        if self.t % self.target_update_frequency == 0:
            self.sync_target_network()

        if self.last_state is not None:
            assert self.last_action is not None
            # Add a transition to the replay buffer
            self.replay_buffer.append(state=self.last_state,
                                      action=self.last_action,
                                      reward=reward,
                                      next_state=state,
                                      next_action=action,
                                      is_state_terminal=False)

        self.last_state = state
        self.last_action = action

        self.replay_updater.update_if_necessary(self.t)

        return self.last_action

    def act(self, state):

        s = self.batch_states([state], self.xp, self.phi)
        action = self.policy(s, test=True).sample()
        # Q is not needed here, but log it just for information
        q = self.q_function(s, action, test=True)

        # Update stats
        self.average_q *= self.average_q_decay
        self.average_q += (1 - self.average_q_decay) * float(q.data)

        self.logger.debug('t:%s a:%s q:%s', self.t, action.data[0], q.data)
        return cuda.to_cpu(action.data[0])

    def stop_episode_and_train(self, state, reward, done=False):

        assert self.last_state is not None
        assert self.last_action is not None

        # Add a transition to the replay buffer
        self.replay_buffer.append(state=self.last_state,
                                  action=self.last_action,
                                  reward=reward,
                                  next_state=state,
                                  next_action=self.last_action,
                                  is_state_terminal=done)

        self.stop_episode()

    def stop_episode(self):
        self.last_state = None
        self.last_action = None
        if isinstance(self.model, Recurrent):
            self.model.reset_state()
        self.replay_buffer.stop_current_episode()

    def get_statistics(self):
        return [
            ('average_q', self.average_q),
            ('average_actor_loss', self.average_actor_loss),
            ('average_critic_loss', self.average_critic_loss),
        ]
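
# The two losses above reduce to short formulas once the Chainer plumbing is
# stripped away. The following is a minimal NumPy sketch, illustrative only:
# ddpg_critic_target, ddpg_actor_loss, and the stand-in pi/qf callables are
# hypothetical names, not part of ChainerRL.
import numpy as np


def ddpg_critic_target(rewards, terminals, next_states, target_policy,
                       target_q, gamma):
    # TD target used in compute_critic_loss:
    #   r + gamma * (1 - done) * Q'(s_{t+1}, mu'(s_{t+1}))
    next_actions = target_policy(next_states)
    next_q = target_q(next_states, next_actions)
    return rewards + gamma * (1.0 - terminals) * next_q


def ddpg_actor_loss(states, policy, q):
    # Negated mean Q of on-policy actions, as in compute_actor_loss.
    return -np.mean(q(states, policy(states)))


# Toy usage with stand-in policy/Q callables:
rng = np.random.RandomState(0)
s = rng.randn(4, 3).astype(np.float32)
r = rng.randn(4).astype(np.float32)
done = np.zeros(4, dtype=np.float32)
pi = lambda x: np.tanh(x)               # fake deterministic policy
qf = lambda x, a: (x * a).sum(axis=1)   # fake Q-function
print(ddpg_critic_target(r, done, s, pi, qf, gamma=0.99))
print(ddpg_actor_loss(s, pi, qf))
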
Example #2
class DQN(agent.AttributeSavingMixin, agent.Agent):
    """Deep Q-Network algorithm.

    Args:
        q_function (StateQFunction): Q-function
        optimizer (Optimizer): Optimizer that is already setup
        replay_buffer (ReplayBuffer): Replay buffer
        gamma (float): Discount factor
        explorer (Explorer): Explorer that specifies an exploration strategy.
        gpu (int): GPU device id if not None nor negative.
        replay_start_size (int): if the replay buffer's size is less than
            replay_start_size, skip update
        minibatch_size (int): Minibatch size
        update_interval (int): Model update interval in step
        target_update_interval (int): Target model update interval in step
        clip_delta (bool): Clip delta if set True
        phi (callable): Feature extractor applied to observations
        target_update_method (str): 'hard' or 'soft'.
        soft_update_tau (float): Tau of soft target update.
        n_times_update (int): Number of repetitions of update
        average_q_decay (float): Decay rate of average Q, only used for
            recording statistics
        average_loss_decay (float): Decay rate of average loss, only used for
            recording statistics
        batch_accumulator (str): 'mean' or 'sum'
        episodic_update (bool): Use full episodes for update if set True
        episodic_update_len (int or None): Subsequences of this length are used
            for update if set int and episodic_update=True
        logger (Logger): Logger used
        batch_states (callable): method which makes a batch of observations.
            default is `chainerrl.misc.batch_states.batch_states`
    """

    saved_attributes = ('model', 'target_model', 'optimizer')

    def __init__(self,
                 q_function,
                 optimizer,
                 replay_buffer,
                 gamma,
                 explorer,
                 gpu=None,
                 replay_start_size=50000,
                 minibatch_size=32,
                 update_interval=1,
                 target_update_interval=10000,
                 clip_delta=True,
                 phi=lambda x: x,
                 target_update_method='hard',
                 soft_update_tau=1e-2,
                 n_times_update=1,
                 average_q_decay=0.999,
                 average_loss_decay=0.99,
                 batch_accumulator='mean',
                 episodic_update=False,
                 episodic_update_len=None,
                 logger=getLogger(__name__),
                 batch_states=batch_states):
        self.model = q_function
        self.q_function = q_function  # For backward compatibility

        if gpu is not None and gpu >= 0:
            cuda.get_device(gpu).use()
            self.model.to_gpu(device=gpu)

        self.xp = self.model.xp
        self.replay_buffer = replay_buffer
        self.optimizer = optimizer
        self.gamma = gamma
        self.explorer = explorer
        self.gpu = gpu
        self.target_update_interval = target_update_interval
        self.clip_delta = clip_delta
        self.phi = phi
        self.target_update_method = target_update_method
        self.soft_update_tau = soft_update_tau
        self.batch_accumulator = batch_accumulator
        assert batch_accumulator in ('mean', 'sum')
        self.logger = logger
        self.batch_states = batch_states
        if episodic_update:
            update_func = self.update_from_episodes
        else:
            update_func = self.update
        self.replay_updater = ReplayUpdater(
            replay_buffer=replay_buffer,
            update_func=update_func,
            batchsize=minibatch_size,
            episodic_update=episodic_update,
            episodic_update_len=episodic_update_len,
            n_times_update=n_times_update,
            replay_start_size=replay_start_size,
            update_interval=update_interval,
        )

        self.t = 0
        self.last_state = None
        self.last_action = None
        self.target_model = None
        self.sync_target_network()
        # For backward compatibility
        self.target_q_function = self.target_model
        self.average_q = 0
        self.average_q_decay = average_q_decay
        self.average_loss = 0
        self.average_loss_decay = average_loss_decay

    def sync_target_network(self):
        """Synchronize target network with current network."""
        if self.target_model is None:
            self.target_model = copy.deepcopy(self.model)
            call_orig = self.target_model.__call__

            def call_test(self_, x):
                with chainer.using_config('train', False):
                    return call_orig(self_, x)

            self.target_model.__call__ = call_test
        else:
            synchronize_parameters(src=self.model,
                                   dst=self.target_model,
                                   method=self.target_update_method,
                                   tau=self.soft_update_tau)

    def update(self, experiences, errors_out=None):
        """Update the model from experiences

        This function is thread-safe.
        Args:
          experiences (list): list of dict that contains
            state: cupy.ndarray or numpy.ndarray
            action: int [0, n_action_types)
            reward: float32
            next_state: cupy.ndarray or numpy.ndarray
            next_legal_actions: list of booleans; True means legal
          gamma (float): discount factor
        Returns:
          None
        """

        has_weight = 'weight' in experiences[0]
        exp_batch = batch_experiences(experiences,
                                      xp=self.xp,
                                      phi=self.phi,
                                      batch_states=self.batch_states)
        if has_weight:
            exp_batch['weights'] = self.xp.asarray(
                [elem['weight'] for elem in experiences],
                dtype=self.xp.float32)
            if errors_out is None:
                errors_out = []
        loss = self._compute_loss(exp_batch, self.gamma, errors_out=errors_out)
        if has_weight:
            self.replay_buffer.update_errors(errors_out)

        # Update stats
        self.average_loss *= self.average_loss_decay
        self.average_loss += (1 - self.average_loss_decay) * float(loss.data)

        self.model.cleargrads()
        loss.backward()
        self.optimizer.update()

    def input_initial_batch_to_target_model(self, batch):
        self.target_model(batch['state'])

    def update_from_episodes(self, episodes, errors_out=None):
        has_weights = isinstance(episodes, tuple)
        if has_weights:
            episodes, weights = episodes
            if errors_out is None:
                errors_out = []
        if errors_out is None:
            errors_out_step = None
        else:
            del errors_out[:]
            for _ in episodes:
                errors_out.append(0.0)
            errors_out_step = []
        with state_reset(self.model):
            with state_reset(self.target_model):
                loss = 0
                tmp = list(
                    reversed(
                        sorted(enumerate(episodes), key=lambda x: len(x[1]))))
                sorted_episodes = [elem[1] for elem in tmp]
                indices = [elem[0] for elem in tmp]  # argsort
                max_epi_len = len(sorted_episodes[0])
                for i in range(max_epi_len):
                    transitions = []
                    weights_step = []
                    for ep, index in zip(sorted_episodes, indices):
                        if len(ep) <= i:
                            break
                        transitions.append(ep[i])
                        if has_weights:
                            weights_step.append(weights[index])
                    batch = batch_experiences(transitions,
                                              xp=self.xp,
                                              phi=self.phi,
                                              batch_states=self.batch_states)
                    if i == 0:
                        self.input_initial_batch_to_target_model(batch)
                    if has_weights:
                        batch['weights'] = self.xp.asarray(
                            weights_step, dtype=self.xp.float32)
                    loss += self._compute_loss(batch,
                                               self.gamma,
                                               errors_out=errors_out_step)
                    if errors_out is not None:
                        for err, index in zip(errors_out_step, indices):
                            errors_out[index] += err
                loss /= max_epi_len

                # Update stats
                self.average_loss *= self.average_loss_decay
                self.average_loss += \
                    (1 - self.average_loss_decay) * float(loss.data)

                self.model.cleargrads()
                loss.backward()
                self.optimizer.update()
        if has_weights:
            self.replay_buffer.update_errors(errors_out)

    def _compute_target_values(self, exp_batch, gamma):

        batch_next_state = exp_batch['next_state']

        target_next_qout = self.target_model(batch_next_state)
        next_q_max = target_next_qout.max

        batch_rewards = exp_batch['reward']
        batch_terminal = exp_batch['is_state_terminal']

        return batch_rewards + gamma * (1.0 - batch_terminal) * next_q_max

    def _compute_y_and_t(self, exp_batch, gamma):
        batch_size = exp_batch['reward'].shape[0]

        # Compute Q-values for current states
        batch_state = exp_batch['state']

        qout = self.model(batch_state)

        batch_actions = exp_batch['action']
        batch_q = F.reshape(qout.evaluate_actions(batch_actions),
                            (batch_size, 1))

        with chainer.no_backprop_mode():
            batch_q_target = F.reshape(
                self._compute_target_values(exp_batch, gamma), (batch_size, 1))

        return batch_q, batch_q_target

    def _compute_loss(self, exp_batch, gamma, errors_out=None):
        """Compute the Q-learning loss for a batch of experiences


        Args:
          experiences (list): see update()'s docstring
          gamma (float): discount factor
        Returns:
          loss
        """

        y, t = self._compute_y_and_t(exp_batch, gamma)

        if errors_out is not None:
            del errors_out[:]
            delta = F.sum(abs(y - t), axis=1)
            delta = cuda.to_cpu(delta.data)
            for e in delta:
                errors_out.append(e)

        if 'weights' in exp_batch:
            return compute_weighted_value_loss(
                y,
                t,
                exp_batch['weights'],
                clip_delta=self.clip_delta,
                batch_accumulator=self.batch_accumulator)
        else:
            return compute_value_loss(y,
                                      t,
                                      clip_delta=self.clip_delta,
                                      batch_accumulator=self.batch_accumulator)

    def compute_q_values(self, states):
        """Compute Q-values

        Args:
          states (list of cupy.ndarray or numpy.ndarray)
        Returns:
          list of numpy.ndarray
        """
        with chainer.using_config('train', False):
            if not states:
                return []
            batch_x = self.batch_states(states, self.xp, self.phi)
            q_values = list(cuda.to_cpu(self.model(batch_x).q_values))
            return q_values

    def _to_my_device(self, model):
        if self.gpu is not None and self.gpu >= 0:
            model.to_gpu(self.gpu)
        else:
            model.to_cpu()

    def act(self, state):
        with chainer.using_config('train', False):
            with chainer.no_backprop_mode():
                action_value = self.model(
                    self.batch_states([state], self.xp, self.phi))
                q = float(action_value.max.data)
                action = cuda.to_cpu(action_value.greedy_actions.data)[0]

        # Update stats
        self.average_q *= self.average_q_decay
        self.average_q += (1 - self.average_q_decay) * q

        self.logger.debug('t:%s q:%s action_value:%s', self.t, q, action_value)
        return action

    def act_and_train(self, state, reward):

        with chainer.using_config('train', False):
            with chainer.no_backprop_mode():
                action_value = self.model(
                    self.batch_states([state], self.xp, self.phi))
                q = float(action_value.max.data)
                greedy_action = cuda.to_cpu(
                    action_value.greedy_actions.data)[0]

        # Update stats
        self.average_q *= self.average_q_decay
        self.average_q += (1 - self.average_q_decay) * q

        self.logger.debug('t:%s q:%s action_value:%s', self.t, q, action_value)

        action = self.explorer.select_action(self.t,
                                             lambda: greedy_action,
                                             action_value=action_value)
        self.t += 1

        # Update the target network
        if self.t % self.target_update_interval == 0:
            self.sync_target_network()

        if self.last_state is not None:
            assert self.last_action is not None
            # Add a transition to the replay buffer
            self.replay_buffer.append(state=self.last_state,
                                      action=self.last_action,
                                      reward=reward,
                                      next_state=state,
                                      next_action=action,
                                      is_state_terminal=False)

        self.last_state = state
        self.last_action = action

        self.replay_updater.update_if_necessary(self.t)

        self.logger.debug('t:%s r:%s a:%s', self.t, reward, action)

        return self.last_action

    def stop_episode_and_train(self, state, reward, done=False):
        """Observe a terminal state and a reward.

        This function must be called once when an episode terminates.
        """

        assert self.last_state is not None
        assert self.last_action is not None

        # Add a transition to the replay buffer
        self.replay_buffer.append(state=self.last_state,
                                  action=self.last_action,
                                  reward=reward,
                                  next_state=state,
                                  next_action=self.last_action,
                                  is_state_terminal=done)

        self.stop_episode()

    def stop_episode(self):
        self.last_state = None
        self.last_action = None
        if isinstance(self.model, Recurrent):
            self.model.reset_state()
        self.replay_buffer.stop_current_episode()

    def get_statistics(self):
        return [
            ('average_q', self.average_q),
            ('average_loss', self.average_loss),
        ]
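
# synchronize_parameters is imported from ChainerRL and not shown in this
# listing. The sketch below illustrates the idea behind its 'hard' and 'soft'
# target_update_method settings over plain parameter dicts; the helper name
# sync_params is hypothetical and the real helper walks the Link hierarchy.
import numpy as np


def sync_params(src, dst, method='hard', tau=1e-2):
    """Copy (hard) or Polyak-average (soft) src parameters into dst."""
    for name, p_src in src.items():
        if method == 'hard':
            dst[name] = p_src.copy()
        elif method == 'soft':
            # theta_dst <- tau * theta_src + (1 - tau) * theta_dst
            dst[name] = tau * p_src + (1.0 - tau) * dst[name]
        else:
            raise ValueError(method)


# Usage: a hard sync makes the target weights match exactly, while a soft
# sync only moves them a fraction tau towards the online weights.
src = {'W': np.ones((2, 2), dtype=np.float32)}
dst = {'W': np.zeros((2, 2), dtype=np.float32)}
sync_params(src, dst, method='soft', tau=0.01)
print(dst['W'])  # every entry is 0.01
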
Example #3
class PGT(AttributeSavingMixin, Agent):
    """Policy Gradient Theorem with an approximate policy and a Q-function.

    This agent is almost the same as DDPG except that it uses likelihood-ratio
    gradient estimation instead of value gradients.

    Args:
        model (chainer.Chain): Chain that contains both a policy and a
            Q-function
        actor_optimizer (Optimizer): Optimizer setup with the policy
        critic_optimizer (Optimizer): Optimizer setup with the Q-function
        replay_buffer (ReplayBuffer): Replay buffer
        gamma (float): Discount factor
        explorer (Explorer): Explorer that specifies an exploration strategy.
        gpu (int): GPU device id. -1 for CPU.
        replay_start_size (int): if the replay buffer's size is less than
            replay_start_size, skip update
        minibatch_size (int): Minibatch size
        update_interval (int): Model update interval in step
        target_update_interval (int): Target model update interval in step
        phi (callable): Feature extractor applied to observations
        target_update_method (str): 'hard' or 'soft'.
        soft_update_tau (float): Tau of soft target update.
        n_times_update (int): Number of repetitions of update
        average_q_decay (float): Decay rate of average Q, only used for
            recording statistics
        average_loss_decay (float): Decay rate of average loss, only used for
            recording statistics
        logger (Logger): Logger used
        beta (float): Coefficient for entropy regularization
        act_deterministically (bool): Act deterministically by selecting most
            probable actions in test time
        batch_states (callable): method which makes a batch of observations.
            default is `chainerrl.misc.batch_states.batch_states`
    """

    saved_attributes = ('model', 'target_model', 'actor_optimizer',
                        'critic_optimizer')

    def __init__(self,
                 model,
                 actor_optimizer,
                 critic_optimizer,
                 replay_buffer,
                 gamma,
                 explorer,
                 beta=1e-2,
                 act_deterministically=False,
                 gpu=-1,
                 replay_start_size=50000,
                 minibatch_size=32,
                 update_interval=1,
                 target_update_interval=10000,
                 phi=lambda x: x,
                 target_update_method='hard',
                 soft_update_tau=1e-2,
                 n_times_update=1,
                 average_q_decay=0.999,
                 average_loss_decay=0.99,
                 logger=getLogger(__name__),
                 batch_states=batch_states):

        self.model = model

        if gpu is not None and gpu >= 0:
            cuda.get_device(gpu).use()
            self.model.to_gpu(device=gpu)

        self.xp = self.model.xp
        self.replay_buffer = replay_buffer
        self.gamma = gamma
        self.explorer = explorer
        self.gpu = gpu
        self.target_update_interval = target_update_interval
        self.phi = phi
        self.target_update_method = target_update_method
        self.soft_update_tau = soft_update_tau
        self.logger = logger
        self.average_q_decay = average_q_decay
        self.average_loss_decay = average_loss_decay
        self.actor_optimizer = actor_optimizer
        self.critic_optimizer = critic_optimizer
        self.beta = beta
        self.act_deterministically = act_deterministically
        self.replay_updater = ReplayUpdater(
            replay_buffer=replay_buffer,
            update_func=self.update,
            batchsize=minibatch_size,
            episodic_update=False,
            n_times_update=n_times_update,
            replay_start_size=replay_start_size,
            update_interval=update_interval,
        )
        self.batch_states = batch_states

        self.t = 0
        self.last_state = None
        self.last_action = None
        self.target_model = copy.deepcopy(self.model)
        disable_train(self.target_model['q_function'])
        disable_train(self.target_model['policy'])
        self.average_q = 0
        self.average_actor_loss = 0.0
        self.average_critic_loss = 0.0

        # Aliases for convenience
        self.q_function = self.model['q_function']
        self.policy = self.model['policy']
        self.target_q_function = self.target_model['q_function']
        self.target_policy = self.target_model['policy']

        self.sync_target_network()

    def sync_target_network(self):
        """Synchronize target network with current network."""
        synchronize_parameters(src=self.model,
                               dst=self.target_model,
                               method=self.target_update_method,
                               tau=self.soft_update_tau)

    def update(self, experiences, errors_out=None):
        """Update the model from experiences."""

        batch_size = len(experiences)

        batch_exp = batch_experiences(
            experiences,
            xp=self.xp,
            phi=self.phi,
            gamma=self.gamma,
            batch_states=self.batch_states,
        )

        batch_state = batch_exp['state']
        batch_actions = batch_exp['action']
        batch_next_state = batch_exp['next_state']
        batch_rewards = batch_exp['reward']
        batch_terminal = batch_exp['is_state_terminal']
        batch_discount = batch_exp['discount']

        # Update Q-function
        def compute_critic_loss():

            with chainer.no_backprop_mode():
                pout = self.target_policy(batch_next_state)
                next_actions = pout.sample()
                next_q = self.target_q_function(batch_next_state, next_actions)
                assert next_q.shape == (batch_size, 1)

                target_q = (batch_rewards[..., None] +
                            (batch_discount[..., None] *
                             (1.0 - batch_terminal[..., None]) * next_q))
                assert target_q.shape == (batch_size, 1)

            predict_q = self.q_function(batch_state, batch_actions)
            assert predict_q.shape == (batch_size, 1)

            loss = F.mean_squared_error(target_q, predict_q)

            # Update stats
            self.average_critic_loss *= self.average_loss_decay
            self.average_critic_loss += ((1 - self.average_loss_decay) *
                                         float(loss.array))

            return loss

        def compute_actor_loss():
            pout = self.policy(batch_state)
            sampled_actions = pout.sample().array
            log_probs = pout.log_prob(sampled_actions)
            with chainer.using_config('train', False):
                q = self.q_function(batch_state, sampled_actions)
                v = self.q_function(batch_state, pout.most_probable)
            advantage = F.reshape(q - v, (batch_size, ))
            advantage = chainer.Variable(advantage.array)
            loss = - F.sum(advantage * log_probs + self.beta * pout.entropy) \
                / batch_size

            # Update stats
            self.average_actor_loss *= self.average_loss_decay
            self.average_actor_loss += ((1 - self.average_loss_decay) *
                                        float(loss.array))

            return loss

        self.critic_optimizer.update(compute_critic_loss)
        self.actor_optimizer.update(compute_actor_loss)

    def act_and_train(self, obs, reward):

        self.logger.debug('t:%s r:%s', self.t, reward)

        greedy_action = self.act(obs)
        action = self.explorer.select_action(self.t, lambda: greedy_action)
        self.t += 1

        # Update the target network
        if self.t % self.target_update_interval == 0:
            self.sync_target_network()

        if self.last_state is not None:
            assert self.last_action is not None
            # Add a transition to the replay buffer
            self.replay_buffer.append(state=self.last_state,
                                      action=self.last_action,
                                      reward=reward,
                                      next_state=obs,
                                      next_action=action,
                                      is_state_terminal=False)

        self.last_state = obs
        self.last_action = action

        self.replay_updater.update_if_necessary(self.t)

        return self.last_action

    def act(self, obs):

        with chainer.using_config('train', False):
            s = self.batch_states([obs], self.xp, self.phi)
            if self.act_deterministically:
                action = self.policy(s).most_probable
            else:
                action = self.policy(s).sample()
            # Q is not needed here, but log it just for information
            q = self.q_function(s, action)

        # Update stats
        self.average_q *= self.average_q_decay
        self.average_q += (1 - self.average_q_decay) * float(q.array)

        self.logger.debug('t:%s a:%s q:%s', self.t, action.array[0], q.array)
        return cuda.to_cpu(action.array[0])

    def stop_episode_and_train(self, state, reward, done=False):

        assert self.last_state is not None
        assert self.last_action is not None

        # Add a transition to the replay buffer
        self.replay_buffer.append(state=self.last_state,
                                  action=self.last_action,
                                  reward=reward,
                                  next_state=state,
                                  next_action=self.last_action,
                                  is_state_terminal=done)

        self.stop_episode()

    def stop_episode(self):
        self.last_state = None
        self.last_action = None
        if isinstance(self.model, Recurrent):
            self.model.reset_state()
        self.replay_buffer.stop_current_episode()

    def select_action(self, state):
        return self.explorer.select_action(self.t, lambda: self.act(state))

    def get_statistics(self):
        return [
            ('average_q', self.average_q),
            ('average_actor_loss', self.average_actor_loss),
            ('average_critic_loss', self.average_critic_loss),
        ]
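
# The actor update in PGT's compute_actor_loss follows the likelihood-ratio
# form of the policy gradient theorem instead of DDPG's value gradient. Below
# is a minimal NumPy sketch of that loss; every name is a hypothetical
# stand-in for the batched quantities computed above.
import numpy as np


def pgt_actor_loss(log_probs, q_sampled, q_most_probable, entropy, beta=1e-2):
    # Advantage Q(s, a) - Q(s, mu(s)), treated as a constant baseline,
    # weights the log-probabilities; an entropy bonus scaled by beta is added.
    advantage = q_sampled - q_most_probable
    return -np.mean(advantage * log_probs + beta * entropy)


# Toy usage: a positive advantage and high entropy both lower the loss,
# pushing probability mass towards the sampled action.
print(pgt_actor_loss(log_probs=np.array([-1.2, -0.3]),
                     q_sampled=np.array([1.0, 0.5]),
                     q_most_probable=np.array([0.8, 0.7]),
                     entropy=np.array([0.9, 0.9])))
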
Example #4
class DQN(agent.AttributeSavingMixin, agent.BatchAgent):
    """Deep Q-Network algorithm.

    Args:
        q_function (StateQFunction): Q-function
        optimizer (Optimizer): Optimizer that is already setup
        replay_buffer (ReplayBuffer): Replay buffer
        gamma (float): Discount factor
        explorer (Explorer): Explorer that specifies an exploration strategy.
        gpu (int): GPU device id if not None nor negative.
        replay_start_size (int): if the replay buffer's size is less than
            replay_start_size, skip update
        minibatch_size (int): Minibatch size
        update_interval (int): Model update interval in step
        target_update_interval (int): Target model update interval in step
        clip_delta (bool): Clip delta if set True
        phi (callable): Feature extractor applied to observations
        target_update_method (str): 'hard' or 'soft'.
        soft_update_tau (float): Tau of soft target update.
        n_times_update (int): Number of repetitions of update
        average_q_decay (float): Decay rate of average Q, only used for
            recording statistics
        average_loss_decay (float): Decay rate of average loss, only used for
            recording statistics
        batch_accumulator (str): 'mean' or 'sum'
        episodic_update (bool): Use full episodes for update if set True
        episodic_update_len (int or None): Subsequences of this length are used
            for update if set int and episodic_update=True
        logger (Logger): Logger used
        batch_states (callable): method which makes a batch of observations.
            default is `chainerrl.misc.batch_states.batch_states`
    """

    saved_attributes = ('model', 'target_model', 'optimizer')

    def __init__(self,
                 q_function,
                 optimizer,
                 replay_buffer,
                 gamma,
                 explorer,
                 gpu=None,
                 replay_start_size=50000,
                 minibatch_size=32,
                 update_interval=1,
                 target_update_interval=10000,
                 clip_delta=True,
                 phi=lambda x: x,
                 target_update_method='hard',
                 soft_update_tau=1e-2,
                 n_times_update=1,
                 average_q_decay=0.999,
                 average_loss_decay=0.99,
                 batch_accumulator='mean',
                 episodic_update=False,
                 episodic_update_len=None,
                 logger=getLogger(__name__),
                 batch_states=batch_states):
        self.model = q_function
        self.q_function = q_function  # For backward compatibility

        if gpu is not None and gpu >= 0:
            cuda.get_device(gpu).use()
            self.model.to_gpu(device=gpu)

        self.xp = self.model.xp
        self.replay_buffer = replay_buffer
        self.optimizer = optimizer
        self.gamma = gamma
        self.explorer = explorer
        self.gpu = gpu
        self.target_update_interval = target_update_interval
        self.clip_delta = clip_delta
        self.phi = phi
        self.target_update_method = target_update_method
        self.soft_update_tau = soft_update_tau
        self.batch_accumulator = batch_accumulator
        assert batch_accumulator in ('mean', 'sum')
        self.logger = logger
        self.batch_states = batch_states
        if episodic_update:
            update_func = self.update_from_episodes
        else:
            update_func = self.update
        self.replay_updater = ReplayUpdater(
            replay_buffer=replay_buffer,
            update_func=update_func,
            batchsize=minibatch_size,
            episodic_update=episodic_update,
            episodic_update_len=episodic_update_len,
            n_times_update=n_times_update,
            replay_start_size=replay_start_size,
            update_interval=update_interval,
        )

        self.t = 0
        self.last_state = None
        self.last_action = None
        self.target_model = None
        self.sync_target_network()
        # For backward compatibility
        self.target_q_function = self.target_model
        self.average_q = 0
        self.average_q_decay = average_q_decay
        self.average_loss = 0
        self.average_loss_decay = average_loss_decay

        # Error checking
        if (self.replay_buffer.capacity is not None
                and self.replay_buffer.capacity <
                self.replay_updater.replay_start_size):
            raise ValueError('Replay start size cannot exceed '
                             'replay buffer capacity.')

    def sync_target_network(self):
        """Synchronize target network with current network."""
        if self.target_model is None:
            self.target_model = copy.deepcopy(self.model)
            call_orig = self.target_model.__call__

            def call_test(self_, x):
                with chainer.using_config('train', False):
                    return call_orig(self_, x)

            self.target_model.__call__ = call_test
        else:
            synchronize_parameters(src=self.model,
                                   dst=self.target_model,
                                   method=self.target_update_method,
                                   tau=self.soft_update_tau)

    def update(self, experiences, errors_out=None):
        """Update the model from experiences

        Args:
            experiences (list): List of lists of dicts.
                For DQN, each dict must contain:
                  - state (object): State
                  - action (object): Action
                  - reward (float): Reward
                  - is_state_terminal (bool): True iff next state is terminal
                  - next_state (object): Next state
                  - weight (float, optional): Weight coefficient. It can be
                    used for importance sampling.
            errors_out (list or None): If set to a list, then TD-errors
                computed from the given experiences are appended to the list.

        Returns:
            None
        """
        has_weight = 'weight' in experiences[0][0]
        exp_batch = batch_experiences(experiences,
                                      xp=self.xp,
                                      phi=self.phi,
                                      gamma=self.gamma,
                                      batch_states=self.batch_states)
        if has_weight:
            exp_batch['weights'] = self.xp.asarray(
                [elem[0]['weight'] for elem in experiences],
                dtype=self.xp.float32)
            if errors_out is None:
                errors_out = []
        loss = self._compute_loss(exp_batch, errors_out=errors_out)
        if has_weight:
            self.replay_buffer.update_errors(errors_out)

        # Update stats
        self.average_loss *= self.average_loss_decay
        self.average_loss += (1 - self.average_loss_decay) * float(loss.array)

        self.model.cleargrads()
        loss.backward()
        self.optimizer.update()

    def input_initial_batch_to_target_model(self, batch):
        self.target_model(batch['state'])

    def update_from_episodes(self, episodes, errors_out=None):
        has_weights = isinstance(episodes, tuple)
        if has_weights:
            episodes, weights = episodes
            if errors_out is None:
                errors_out = []
        if errors_out is None:
            errors_out_step = None
        else:
            del errors_out[:]
            for _ in episodes:
                errors_out.append(0.0)
            errors_out_step = []

        with state_reset(self.model), state_reset(self.target_model):
            loss = 0
            tmp = list(
                reversed(sorted(enumerate(episodes), key=lambda x: len(x[1]))))
            sorted_episodes = [elem[1] for elem in tmp]
            indices = [elem[0] for elem in tmp]  # argsort
            max_epi_len = len(sorted_episodes[0])
            for i in range(max_epi_len):
                transitions = []
                weights_step = []
                for ep, index in zip(sorted_episodes, indices):
                    if len(ep) <= i:
                        break
                    transitions.append([ep[i]])
                    if has_weights:
                        weights_step.append(weights[index])
                batch = batch_experiences(transitions,
                                          xp=self.xp,
                                          phi=self.phi,
                                          gamma=self.gamma,
                                          batch_states=self.batch_states)
                assert len(batch['state']) == len(transitions)
                if i == 0:
                    self.input_initial_batch_to_target_model(batch)
                if has_weights:
                    batch['weights'] = self.xp.asarray(weights_step,
                                                       dtype=self.xp.float32)
                loss += self._compute_loss(batch, errors_out=errors_out_step)
                if errors_out is not None:
                    for err, index in zip(errors_out_step, indices):
                        errors_out[index] += err
            loss /= max_epi_len

            # Update stats
            self.average_loss *= self.average_loss_decay
            self.average_loss += \
                (1 - self.average_loss_decay) * float(loss.array)

            self.model.cleargrads()
            loss.backward()
            self.optimizer.update()
        if has_weights:
            self.replay_buffer.update_errors(errors_out)

    def _compute_target_values(self, exp_batch):
        batch_next_state = exp_batch['next_state']

        target_next_qout = self.target_model(batch_next_state)
        next_q_max = target_next_qout.max

        batch_rewards = exp_batch['reward']
        batch_terminal = exp_batch['is_state_terminal']
        discount = exp_batch['discount']

        return batch_rewards + discount * (1.0 - batch_terminal) * next_q_max

    def _compute_y_and_t(self, exp_batch):
        batch_size = exp_batch['reward'].shape[0]

        # Compute Q-values for current states
        batch_state = exp_batch['state']

        qout = self.model(batch_state)

        batch_actions = exp_batch['action']
        batch_q = F.reshape(qout.evaluate_actions(batch_actions),
                            (batch_size, 1))

        with chainer.no_backprop_mode():
            batch_q_target = F.reshape(self._compute_target_values(exp_batch),
                                       (batch_size, 1))

        return batch_q, batch_q_target

    def _compute_loss(self, exp_batch, errors_out=None):
        """Compute the Q-learning loss for a batch of experiences


        Args:
          exp_batch (dict): A dict of batched arrays of transitions
        Returns:
          Computed loss from the minibatch of experiences
        """
        y, t = self._compute_y_and_t(exp_batch)

        if errors_out is not None:
            del errors_out[:]
            delta = F.absolute(y - t)
            if delta.ndim == 2:
                delta = F.sum(delta, axis=1)
            delta = cuda.to_cpu(delta.array)
            for e in delta:
                errors_out.append(e)

        if 'weights' in exp_batch:
            return compute_weighted_value_loss(
                y,
                t,
                exp_batch['weights'],
                clip_delta=self.clip_delta,
                batch_accumulator=self.batch_accumulator)
        else:
            return compute_value_loss(y,
                                      t,
                                      clip_delta=self.clip_delta,
                                      batch_accumulator=self.batch_accumulator)

    def act(self, obs):
        with chainer.using_config('train', False), chainer.no_backprop_mode():
            action_value = self.model(
                self.batch_states([obs], self.xp, self.phi))
            q = float(action_value.max.array)
            action = cuda.to_cpu(action_value.greedy_actions.array)[0]

        # Update stats
        self.average_q *= self.average_q_decay
        self.average_q += (1 - self.average_q_decay) * q

        self.logger.debug('t:%s q:%s action_value:%s', self.t, q, action_value)
        return action

    def act_and_train(self, obs, reward):

        with chainer.using_config('train', False), chainer.no_backprop_mode():
            action_value = self.model(
                self.batch_states([obs], self.xp, self.phi))
            q = float(action_value.max.array)
            greedy_action = cuda.to_cpu(action_value.greedy_actions.array)[0]

        # Update stats
        self.average_q *= self.average_q_decay
        self.average_q += (1 - self.average_q_decay) * q

        self.logger.debug('t:%s q:%s action_value:%s', self.t, q, action_value)

        action = self.explorer.select_action(self.t,
                                             lambda: greedy_action,
                                             action_value=action_value)
        self.t += 1

        # Update the target network
        if self.t % self.target_update_interval == 0:
            self.sync_target_network()

        if self.last_state is not None:
            assert self.last_action is not None
            # Add a transition to the replay buffer
            self.replay_buffer.append(state=self.last_state,
                                      action=self.last_action,
                                      reward=reward,
                                      next_state=obs,
                                      next_action=action,
                                      is_state_terminal=False)

        self.last_state = obs
        self.last_action = action

        self.replay_updater.update_if_necessary(self.t)

        self.logger.debug('t:%s r:%s a:%s', self.t, reward, action)

        return self.last_action

    def batch_act_and_train(self, batch_obs):
        with chainer.using_config('train', False), chainer.no_backprop_mode():
            batch_xs = self.batch_states(batch_obs, self.xp, self.phi)
            batch_av = self.model(batch_xs)
            batch_maxq = batch_av.max.array
            batch_argmax = cuda.to_cpu(batch_av.greedy_actions.array)
        batch_action = [
            self.explorer.select_action(
                self.t,
                lambda: batch_argmax[i],
                action_value=batch_av[i:i + 1],
            ) for i in range(len(batch_obs))
        ]
        self.batch_last_obs = list(batch_obs)
        self.batch_last_action = list(batch_action)

        # Update stats
        self.average_q *= self.average_q_decay
        self.average_q += (1 - self.average_q_decay) * float(batch_maxq.mean())

        return batch_action

    def batch_act(self, batch_obs):
        with chainer.using_config('train', False), chainer.no_backprop_mode():
            batch_xs = self.batch_states(batch_obs, self.xp, self.phi)
            batch_av = self.model(batch_xs)
            batch_argmax = cuda.to_cpu(batch_av.greedy_actions.array)
            return batch_argmax

    def batch_observe_and_train(self, batch_obs, batch_reward, batch_done,
                                batch_reset):
        for i in range(len(batch_obs)):
            self.t += 1
            # Update the target network
            if self.t % self.target_update_interval == 0:
                self.sync_target_network()
            if self.batch_last_obs[i] is not None:
                assert self.batch_last_action[i] is not None
                # Add a transition to the replay buffer
                self.replay_buffer.append(
                    state=self.batch_last_obs[i],
                    action=self.batch_last_action[i],
                    reward=batch_reward[i],
                    next_state=batch_obs[i],
                    next_action=None,
                    is_state_terminal=batch_done[i],
                    env_id=i,
                )
                if batch_reset[i] or batch_done[i]:
                    self.batch_last_obs[i] = None
                    self.replay_buffer.stop_current_episode(env_id=i)
            self.replay_updater.update_if_necessary(self.t)

    def batch_observe(self, batch_obs, batch_reward, batch_done, batch_reset):
        pass

    def stop_episode_and_train(self, state, reward, done=False):
        """Observe a terminal state and a reward.

        This function must be called once when an episode terminates.
        """

        assert self.last_state is not None
        assert self.last_action is not None

        # Add a transition to the replay buffer
        self.replay_buffer.append(state=self.last_state,
                                  action=self.last_action,
                                  reward=reward,
                                  next_state=state,
                                  next_action=self.last_action,
                                  is_state_terminal=done)

        self.stop_episode()

    def stop_episode(self):
        self.last_state = None
        self.last_action = None
        if isinstance(self.model, Recurrent):
            self.model.reset_state()
        self.replay_buffer.stop_current_episode()

    def get_statistics(self):
        return [
            ('average_q', self.average_q),
            ('average_loss', self.average_loss),
            ('n_updates', self.optimizer.t),
        ]
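
# compute_value_loss and compute_weighted_value_loss are imported from
# ChainerRL and not shown in this listing. The following is a minimal NumPy
# sketch of the behaviour _compute_loss relies on, assuming clip_delta
# selects a Huber-style loss and batch_accumulator chooses the reduction;
# the function name value_loss is hypothetical.
import numpy as np


def value_loss(y, t, weights=None, clip_delta=True, batch_accumulator='mean'):
    diff = y - t
    if clip_delta:
        # Huber loss with delta=1: quadratic near zero, linear beyond |diff|=1
        per_sample = np.where(np.abs(diff) < 1.0,
                              0.5 * diff ** 2,
                              np.abs(diff) - 0.5)
    else:
        per_sample = 0.5 * diff ** 2
    if weights is not None:
        per_sample = per_sample * weights  # importance-sampling correction
    if batch_accumulator == 'mean':
        return per_sample.mean()
    elif batch_accumulator == 'sum':
        return per_sample.sum()
    raise ValueError(batch_accumulator)


# Toy usage: the second error (3.0) is clipped into the linear regime.
print(value_loss(np.array([0.2, 3.0]), np.array([0.0, 0.0])))
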
Example #5
class TD3(AttributeSavingMixin, BatchAgent):
    """Twin Delayed Deep Deterministic Policy Gradients (TD3).

    See http://arxiv.org/abs/1802.09477

    Args:
        policy (Policy): Policy.
        q_func1 (Link): First Q-function that takes state-action pairs as input
            and outputs predicted Q-values.
        q_func2 (Link): Second Q-function that takes state-action pairs as
            input and outputs predicted Q-values.
        policy_optimizer (Optimizer): Optimizer setup with the policy
        q_func1_optimizer (Optimizer): Optimizer setup with the first
            Q-function.
        q_func2_optimizer (Optimizer): Optimizer setup with the second
            Q-function.
        replay_buffer (ReplayBuffer): Replay buffer
        gamma (float): Discount factor
        explorer (Explorer): Explorer that specifies an exploration strategy.
        gpu (int): GPU device id if not None nor negative.
        replay_start_size (int): if the replay buffer's size is less than
            replay_start_size, skip update
        minibatch_size (int): Minibatch size
        update_interval (int): Model update interval in step
        phi (callable): Feature extractor applied to observations
        soft_update_tau (float): Tau of soft target update.
        logger (Logger): Logger used
        batch_states (callable): method which makes a batch of observations.
            default is `chainerrl.misc.batch_states.batch_states`
        burnin_action_func (callable or None): If not None, this callable
            object is used to select actions before the model is updated
            one or more times during training.
        policy_update_delay (int): Delay of policy updates. The policy is
            updated once every `policy_update_delay` Q-function updates.
        target_policy_smoothing_func (callable): Callable that takes a batch of
            actions as input and outputs a noisy version of it. It is used for
            target policy smoothing when computing target Q-values.
    """

    saved_attributes = (
        'policy',
        'q_func1',
        'q_func2',
        'target_policy',
        'target_q_func1',
        'target_q_func2',
        'policy_optimizer',
        'q_func1_optimizer',
        'q_func2_optimizer',
    )

    def __init__(
        self,
        policy,
        q_func1,
        q_func2,
        policy_optimizer,
        q_func1_optimizer,
        q_func2_optimizer,
        replay_buffer,
        gamma,
        explorer,
        gpu=None,
        replay_start_size=10000,
        minibatch_size=100,
        update_interval=1,
        phi=lambda x: x,
        soft_update_tau=5e-3,
        n_times_update=1,
        logger=getLogger(__name__),
        batch_states=batch_states,
        burnin_action_func=None,
        policy_update_delay=2,
        target_policy_smoothing_func=default_target_policy_smoothing_func,
    ):

        self.policy = policy
        self.q_func1 = q_func1
        self.q_func2 = q_func2

        if gpu is not None and gpu >= 0:
            cuda.get_device_from_id(gpu).use()
            self.policy.to_gpu(device=gpu)
            self.q_func1.to_gpu(device=gpu)
            self.q_func2.to_gpu(device=gpu)

        self.xp = self.policy.xp
        self.replay_buffer = replay_buffer
        self.gamma = gamma
        self.explorer = explorer
        self.gpu = gpu
        self.phi = phi
        self.soft_update_tau = soft_update_tau
        self.logger = logger
        self.policy_optimizer = policy_optimizer
        self.q_func1_optimizer = q_func1_optimizer
        self.q_func2_optimizer = q_func2_optimizer
        self.replay_updater = ReplayUpdater(
            replay_buffer=replay_buffer,
            update_func=self.update,
            batchsize=minibatch_size,
            n_times_update=1,
            replay_start_size=replay_start_size,
            update_interval=update_interval,
            episodic_update=False,
        )
        self.batch_states = batch_states
        self.burnin_action_func = burnin_action_func
        self.policy_update_delay = policy_update_delay
        self.target_policy_smoothing_func = target_policy_smoothing_func

        self.t = 0
        self.last_state = None
        self.last_action = None

        # Target model
        self.target_policy = copy.deepcopy(self.policy)
        self.target_q_func1 = copy.deepcopy(self.q_func1)
        self.target_q_func2 = copy.deepcopy(self.q_func2)

        # Statistics
        self.q1_record = collections.deque(maxlen=1000)
        self.q2_record = collections.deque(maxlen=1000)
        self.q_func1_loss_record = collections.deque(maxlen=100)
        self.q_func2_loss_record = collections.deque(maxlen=100)

    def sync_target_network(self):
        """Synchronize target network with current network."""
        synchronize_parameters(
            src=self.policy,
            dst=self.target_policy,
            method='soft',
            tau=self.soft_update_tau,
        )
        synchronize_parameters(
            src=self.q_func1,
            dst=self.target_q_func1,
            method='soft',
            tau=self.soft_update_tau,
        )
        synchronize_parameters(
            src=self.q_func2,
            dst=self.target_q_func2,
            method='soft',
            tau=self.soft_update_tau,
        )

    def update_q_func(self, batch):
        """Compute loss for a given Q-function."""

        batch_next_state = batch['next_state']
        batch_rewards = batch['reward']
        batch_terminal = batch['is_state_terminal']
        batch_state = batch['state']
        batch_actions = batch['action']
        batch_discount = batch['discount']

        with chainer.no_backprop_mode(), chainer.using_config('train', False):
            next_actions = self.target_policy_smoothing_func(
                self.target_policy(batch_next_state).sample().array)
            next_q1 = self.target_q_func1(batch_next_state, next_actions)
            next_q2 = self.target_q_func2(batch_next_state, next_actions)
            next_q = F.minimum(next_q1, next_q2)

            target_q = batch_rewards + batch_discount * \
                (1.0 - batch_terminal) * F.flatten(next_q)

        predict_q1 = F.flatten(self.q_func1(batch_state, batch_actions))
        predict_q2 = F.flatten(self.q_func2(batch_state, batch_actions))

        loss1 = F.mean_squared_error(target_q, predict_q1)
        loss2 = F.mean_squared_error(target_q, predict_q2)

        # Update stats
        self.q1_record.extend(cuda.to_cpu(predict_q1.array))
        self.q2_record.extend(cuda.to_cpu(predict_q2.array))
        self.q_func1_loss_record.append(float(loss1.array))
        self.q_func2_loss_record.append(float(loss2.array))

        self.q_func1_optimizer.update(lambda: loss1)
        self.q_func2_optimizer.update(lambda: loss2)

    def update_policy(self, batch):
        """Compute loss for actor."""

        batch_state = batch['state']

        onpolicy_actions = self.policy(batch_state).sample()
        q = self.q_func1(batch_state, onpolicy_actions)

        # Since we want to maximize Q, loss is negation of Q
        loss = -F.mean(q)

        self.policy_optimizer.update(lambda: loss)

    def update(self, experiences, errors_out=None):
        """Update the model from experiences"""

        batch = batch_experiences(experiences, self.xp, self.phi, self.gamma)
        self.update_q_func(batch)
        if self.q_func1_optimizer.t % self.policy_update_delay == 0:
            self.update_policy(batch)
            self.sync_target_network()

    def select_onpolicy_action(self, obs):
        with chainer.no_backprop_mode(), chainer.using_config('train', False):
            s = self.batch_states([obs], self.xp, self.phi)
            action = self.policy(s).sample().array
        return cuda.to_cpu(action)[0]

    def act_and_train(self, obs, reward):

        self.logger.debug('t:%s r:%s', self.t, reward)

        if (self.burnin_action_func is not None
                and self.policy_optimizer.t == 0):
            action = self.burnin_action_func()
        else:
            onpolicy_action = self.select_onpolicy_action(obs)
            action = self.explorer.select_action(self.t,
                                                 lambda: onpolicy_action)
        self.t += 1

        if self.last_state is not None:
            assert self.last_action is not None
            # Add a transition to the replay buffer
            self.replay_buffer.append(state=self.last_state,
                                      action=self.last_action,
                                      reward=reward,
                                      next_state=obs,
                                      next_action=action,
                                      is_state_terminal=False)

        self.last_state = obs
        self.last_action = action

        self.replay_updater.update_if_necessary(self.t)

        return self.last_action

    def act(self, obs):
        return self.select_onpolicy_action(obs)

    def batch_select_onpolicy_action(self, batch_obs):
        with chainer.using_config('train', False), chainer.no_backprop_mode():
            batch_xs = self.batch_states(batch_obs, self.xp, self.phi)
            batch_action = self.policy(batch_xs).sample().array
        return list(cuda.to_cpu(batch_action))

    def batch_act(self, batch_obs):
        return self.batch_select_onpolicy_action(batch_obs)

    def batch_act_and_train(self, batch_obs):
        """Select a batch of actions for training.

        Args:
            batch_obs (Sequence of ~object): Observations.

        Returns:
            Sequence of ~object: Actions.
        """

        if (self.burnin_action_func is not None
                and self.policy_optimizer.t == 0):
            batch_action = [
                self.burnin_action_func() for _ in range(len(batch_obs))
            ]
        else:
            batch_onpolicy_action = self.batch_select_onpolicy_action(
                batch_obs)
            batch_action = [
                self.explorer.select_action(self.t,
                                            lambda: batch_onpolicy_action[i])
                for i in range(len(batch_onpolicy_action))
            ]

        self.batch_last_obs = list(batch_obs)
        self.batch_last_action = list(batch_action)

        return batch_action

    def batch_observe_and_train(self, batch_obs, batch_reward, batch_done,
                                batch_reset):
        for i in range(len(batch_obs)):
            self.t += 1
            if self.batch_last_obs[i] is not None:
                assert self.batch_last_action[i] is not None
                # Add a transition to the replay buffer
                self.replay_buffer.append(
                    state=self.batch_last_obs[i],
                    action=self.batch_last_action[i],
                    reward=batch_reward[i],
                    next_state=batch_obs[i],
                    next_action=None,
                    is_state_terminal=batch_done[i],
                    env_id=i,
                )
                if batch_reset[i] or batch_done[i]:
                    self.batch_last_obs[i] = None
                    self.replay_buffer.stop_current_episode(env_id=i)
            self.replay_updater.update_if_necessary(self.t)

    def batch_observe(self, batch_obs, batch_reward, batch_done, batch_reset):
        pass

    def stop_episode_and_train(self, state, reward, done=False):

        assert self.last_state is not None
        assert self.last_action is not None

        # Add a transition to the replay buffer
        self.replay_buffer.append(state=self.last_state,
                                  action=self.last_action,
                                  reward=reward,
                                  next_state=state,
                                  next_action=self.last_action,
                                  is_state_terminal=done)

        self.stop_episode()

    def stop_episode(self):
        self.last_state = None
        self.last_action = None
        self.replay_buffer.stop_current_episode()

    def get_statistics(self):
        return [
            ('average_q1', _mean_or_nan(self.q1_record)),
            ('average_q2', _mean_or_nan(self.q2_record)),
            ('average_q_func1_loss', _mean_or_nan(self.q_func1_loss_record)),
            ('average_q_func2_loss', _mean_or_nan(self.q_func2_loss_record)),
            ('policy_n_updates', self.policy_optimizer.t),
            ('q_func_n_updates', self.q_func1_optimizer.t),
        ]
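For orientation, the target computed inside `update_q_func` above combines target policy smoothing (clipped noise added to the target policy's action) with the clipped double-Q trick (minimum of the two target Q-functions). Below is a minimal NumPy sketch of that target, not the Chainer implementation; the callables are stand-ins, and the noise parameters (0.2 / 0.5) are the defaults commonly used for TD3.

import numpy as np


def td3_target(reward, done, discount, next_state,
               target_policy, target_q1, target_q2,
               noise_sigma=0.2, noise_clip=0.5):
    """r + discount * (1 - done) * min_i Q_i(s', pi(s') + clipped noise)."""
    next_action = target_policy(next_state)
    noise = np.clip(noise_sigma * np.random.randn(*next_action.shape),
                    -noise_clip, noise_clip)
    next_q = np.minimum(target_q1(next_state, next_action + noise),
                        target_q2(next_state, next_action + noise))
    return reward + discount * (1.0 - done) * next_q


# Toy usage with linear stand-ins for the target networks
pi = lambda s: np.tanh(s @ np.ones((3, 1)))
q = lambda s, a: -np.square(a).sum(axis=1)
targets = td3_target(reward=np.zeros(4), done=np.zeros(4),
                     discount=0.99 * np.ones(4),
                     next_state=np.random.randn(4, 3),
                     target_policy=pi, target_q1=q, target_q2=q)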
Example #6
class DDPG(AttributeSavingMixin, BatchAgent):
    """Deep Deterministic Policy Gradients.

    This can be used as SVG(0) by specifying a Gaussian policy instead of a
    deterministic policy.

    Args:
        model (DDPGModel): DDPG model that contains both a policy and a
            Q-function
        actor_optimizer (Optimizer): Optimizer setup with the policy
        critic_optimizer (Optimizer): Optimizer setup with the Q-function
        replay_buffer (ReplayBuffer): Replay buffer
        gamma (float): Discount factor
        explorer (Explorer): Explorer that specifies an exploration strategy.
        gpu (int): GPU device id if not None nor negative.
        replay_start_size (int): if the replay buffer's size is less than
            replay_start_size, skip update
        minibatch_size (int): Minibatch size
        update_interval (int): Model update interval in step
        target_update_interval (int): Target model update interval in step
        phi (callable): Feature extractor applied to observations
        target_update_method (str): 'hard' or 'soft'.
        soft_update_tau (float): Tau of soft target update.
        n_times_update (int): Number of repetition of update
        average_q_decay (float): Decay rate of average Q, only used for
            recording statistics
        average_loss_decay (float): Decay rate of average loss, only used for
            recording statistics
        batch_accumulator (str): 'mean' or 'sum'
        episodic_update (bool): Use full episodes for update if set True
        episodic_update_len (int or None): Subsequences of this length are used
            for update if set int and episodic_update=True
        logger (Logger): Logger used
        batch_states (callable): method which makes a batch of observations.
            default is `chainerrl.misc.batch_states.batch_states`
        l2_action_penalty (float or None): If not None, an L2 penalty on the
            actions proposed by the policy, scaled by this coefficient, is
            added to the actor loss.
        clip_critic_tgt (tuple or None): Tuple of (min, max) used to clip the
            critic's target values. If None, the target is not clipped.
        burnin_action_func (callable or None): If not None, this callable
            object is used to select actions before the model is updated
            one or more times during training.
    """

    saved_attributes = ('model', 'target_model', 'actor_optimizer',
                        'critic_optimizer', 'obs_normalizer')

    def __init__(
        self,
        model,
        actor_optimizer,
        critic_optimizer,
        replay_buffer,
        gamma,
        explorer,
        obs_normalizer=None,
        gpu=None,
        replay_start_size=50000,
        minibatch_size=32,
        update_interval=1,
        target_update_interval=10000,
        phi=lambda x: x,
        target_update_method='hard',
        soft_update_tau=1e-2,
        n_times_update=1,
        average_q_decay=0.999,
        average_loss_decay=0.99,
        episodic_update=False,
        episodic_update_len=None,
        logger=getLogger(__name__),
        batch_states=batch_states,
        l2_action_penalty=None,
        clip_critic_tgt=None,
        burnin_action_func=None,
    ):

        self.model = model
        self.obs_normalizer = obs_normalizer

        if gpu is not None and gpu >= 0:
            cuda.get_device(gpu).use()
            self.model.to_gpu(device=gpu)
            if self.obs_normalizer is not None:
                self.obs_normalizer.to_gpu(device=gpu)

        self.xp = self.model.xp
        self.replay_buffer = replay_buffer
        self.gamma = gamma
        self.explorer = explorer
        self.gpu = gpu
        self.target_update_interval = target_update_interval
        self.phi = phi
        self.target_update_method = target_update_method
        self.soft_update_tau = soft_update_tau
        self.logger = logger
        self.average_q_decay = average_q_decay
        self.average_loss_decay = average_loss_decay
        self.actor_optimizer = actor_optimizer
        self.critic_optimizer = critic_optimizer
        if episodic_update:
            update_func = self.update_from_episodes
        else:
            update_func = self.update
        self.replay_updater = ReplayUpdater(
            replay_buffer=replay_buffer,
            update_func=update_func,
            batchsize=minibatch_size,
            episodic_update=episodic_update,
            episodic_update_len=episodic_update_len,
            n_times_update=n_times_update,
            replay_start_size=replay_start_size,
            update_interval=update_interval,
        )
        self.batch_states = batch_states
        self.clip_critic_tgt = clip_critic_tgt
        self.l2_action_penalty = l2_action_penalty
        self.burnin_action_func = burnin_action_func

        self.t = 0
        self.last_state = None
        self.last_action = None
        self.target_model = copy.deepcopy(self.model)
        disable_train(self.target_model['q_function'])
        disable_train(self.target_model['policy'])
        self.average_q = 0
        self.average_actor_loss = 0.0
        self.average_critic_loss = 0.0

        # Aliases for convenience
        self.q_function = self.model['q_function']
        self.policy = self.model['policy']
        self.target_q_function = self.target_model['q_function']
        self.target_policy = self.target_model['policy']

        self.sync_target_network()

    def sync_target_network(self):
        """Synchronize target network with current network."""
        synchronize_parameters(src=self.model,
                               dst=self.target_model,
                               method=self.target_update_method,
                               tau=self.soft_update_tau)

    # Update Q-function
    def compute_critic_loss(self, batch):
        """Compute loss for critic.

        Preconditions:
          target_q_function must have seen up to s_t and a_t.
          target_policy must have seen up to s_t.
          q_function must have seen up to s_{t-1} and a_{t-1}.
        Postconditions:
          target_q_function must have seen up to s_{t+1} and a_{t+1}.
          target_policy must have seen up to s_{t+1}.
          q_function must have seen up to s_t and a_t.
        """

        batch_next_state = batch['next_state']
        batch_rewards = batch['reward']
        batch_terminal = batch['is_state_terminal']
        batch_state = batch['state']
        batch_actions = batch['action']
        batchsize = len(batch_rewards)

        with chainer.no_backprop_mode():
            # Target policy observes s_{t+1}
            next_actions = self.target_policy(batch_next_state).sample()

            # Q(s_{t+1}, mu(a_{t+1})) is evaluated.
            # This should not affect the internal state of Q.
            with state_kept(self.target_q_function):
                next_q = self.target_q_function(batch_next_state, next_actions)

            # Target Q-function observes s_{t+1} and a_{t+1}
            if isinstance(self.target_q_function, Recurrent):
                batch_next_actions = batch['next_action']
                self.target_q_function.update_state(batch_next_state,
                                                    batch_next_actions)

            target_q = batch_rewards + self.gamma * \
                (1.0 - batch_terminal) * F.reshape(next_q, (batchsize,))
            if self.clip_critic_tgt:
                target_q = F.clip(target_q, self.clip_critic_tgt[0],
                                  self.clip_critic_tgt[1])

        # Estimated Q-function observes s_t and a_t
        predict_q = F.reshape(self.q_function(batch_state, batch_actions),
                              (batchsize, ))

        loss = F.mean_squared_error(target_q, predict_q)

        # Update stats
        self.average_critic_loss *= self.average_loss_decay
        self.average_critic_loss += ((1 - self.average_loss_decay) *
                                     float(loss.array))

        return loss

    def compute_actor_loss(self, batch):
        """Compute loss for actor.

        Preconditions:
          q_function must have seen up to s_{t-1} and a_{t-1}.
          policy must have seen up to s_{t-1}.
        Postconditions:
          q_function must have seen up to s_t and a_t.
          policy must have seen up to s_t.
        """

        batch_state = batch['state']
        batch_action = batch['action']
        batch_size = len(batch_action)

        # Estimated policy observes s_t
        onpolicy_actions = self.policy(batch_state).sample()

        # Q(s_t, mu(s_t)) is evaluated.
        # This should not affect the internal state of Q.
        with state_kept(self.q_function):
            q = self.q_function(batch_state, onpolicy_actions)

        # Estimated Q-function observes s_t and a_t
        if isinstance(self.q_function, Recurrent):
            self.q_function.update_state(batch_state, batch_action)

        # Avoid the numpy #9165 bug (see also: chainer #2744)
        q = q[:, :]

        # Since we want to maximize Q, loss is negation of Q
        loss = -F.sum(q) / batch_size
        if self.l2_action_penalty:
            loss += (self.l2_action_penalty
                     * F.sum(F.square(onpolicy_actions)) / batch_size)

        # Update stats
        self.average_actor_loss *= self.average_loss_decay
        self.average_actor_loss += ((1 - self.average_loss_decay) *
                                    float(loss.array))
        return loss

    def update(self, experiences, errors_out=None):
        """Update the model from experiences"""
        batch = batch_experiences(experiences, self.xp, self.phi, self.gamma)
        if self.obs_normalizer:
            batch['state'] = self.obs_normalizer(batch['state'], update=False)
            batch['next_state'] = self.obs_normalizer(batch['next_state'],
                                                      update=False)
        self.critic_optimizer.update(lambda: self.compute_critic_loss(batch))
        self.actor_optimizer.update(lambda: self.compute_actor_loss(batch))

    def update_from_episodes(self, episodes, errors_out=None):
        # Sort episodes desc by their lengths
        sorted_episodes = list(reversed(sorted(episodes, key=len)))
        max_epi_len = len(sorted_episodes[0])

        # Precompute all the input batches
        batches = []
        for i in range(max_epi_len):
            transitions = []
            for ep in sorted_episodes:
                if len(ep) <= i:
                    break
                transitions.append([ep[i]])
            batch = batch_experiences(transitions,
                                      xp=self.xp,
                                      phi=self.phi,
                                      gamma=self.gamma)
            if self.obs_normalizer:
                batch['state'] = self.obs_normalizer(batch['state'],
                                                     update=False)
                batch['next_state'] = self.obs_normalizer(batch['next_state'],
                                                          update=False)
            batches.append(batch)

        with self.model.state_reset(), self.target_model.state_reset():

            # Since the target model is evaluated one-step ahead,
            # its internal states need to be updated
            self.target_q_function.update_state(batches[0]['state'],
                                                batches[0]['action'])
            self.target_policy(batches[0]['state'])

            # Update critic through time
            critic_loss = 0
            for batch in batches:
                critic_loss += self.compute_critic_loss(batch)
            self.critic_optimizer.update(lambda: critic_loss / max_epi_len)

        with self.model.state_reset():

            # Update actor through time
            actor_loss = 0
            for batch in batches:
                actor_loss += self.compute_actor_loss(batch)
            self.actor_optimizer.update(lambda: actor_loss / max_epi_len)

    def act_and_train(self, obs, reward):

        self.logger.debug('t:%s r:%s', self.t, reward)

        if (self.burnin_action_func is not None
                and self.actor_optimizer.t == 0):
            action = self.burnin_action_func()
        else:
            greedy_action = self.act(obs)
            action = self.explorer.select_action(self.t, lambda: greedy_action)
        self.t += 1

        # Update the target network
        if self.t % self.target_update_interval == 0:
            self.sync_target_network()

        if self.last_state is not None:
            assert self.last_action is not None
            # Add a transition to the replay buffer
            self.replay_buffer.append(state=self.last_state,
                                      action=self.last_action,
                                      reward=reward,
                                      next_state=obs,
                                      next_action=action,
                                      is_state_terminal=False)
            # Add to Normalizer
            if self.obs_normalizer:
                self.obs_normalizer(self.batch_states([obs], self.xp,
                                                      self.phi))

        self.last_state = obs
        self.last_action = action

        self.replay_updater.update_if_necessary(self.t)

        return self.last_action

    def act(self, obs):
        with chainer.using_config('train', False):
            s = self.batch_states([obs], self.xp, self.phi)
            if self.obs_normalizer:
                s = self.obs_normalizer(s, update=False)
            action = self.policy(s).sample()
            # Q is not needed here, but log it just for information
            q = self.q_function(s, action)

        # Update stats
        self.average_q *= self.average_q_decay
        self.average_q += (1 - self.average_q_decay) * float(q.array)

        self.logger.debug('t:%s a:%s q:%s', self.t, action.array[0], q.array)
        return cuda.to_cpu(action.array[0])

    def batch_act(self, batch_obs):
        """Select a batch of actions for evaluation.

        Args:
            batch_obs (Sequence of ~object): Observations.

        Returns:
            Sequence of ~object: Actions.
        """

        with chainer.using_config('train', False), chainer.no_backprop_mode():
            batch_xs = self.batch_states(batch_obs, self.xp, self.phi)
            if self.obs_normalizer:
                batch_xs = self.obs_normalizer(batch_xs, update=False)
            batch_action = self.policy(batch_xs).sample()
            # Q is not needed here, but log it just for information
            q = self.q_function(batch_xs, batch_action)

        # Update stats
        self.average_q *= self.average_q_decay
        self.average_q += (1 - self.average_q_decay) * float(
            q.array.mean(axis=0))
        self.logger.debug('t:%s a:%s q:%s', self.t, batch_action.array[0],
                          q.array)
        return [cuda.to_cpu(action.array) for action in batch_action]

    def batch_act_and_train(self, batch_obs):
        """Select a batch of actions for training.

        Args:
            batch_obs (Sequence of ~object): Observations.

        Returns:
            Sequence of ~object: Actions.
        """

        if (self.burnin_action_func is not None
                and self.actor_optimizer.t == 0):
            batch_action = [
                self.burnin_action_func() for _ in range(len(batch_obs))
            ]
        else:
            batch_greedy_action = self.batch_act(batch_obs)
            batch_action = [
                self.explorer.select_action(self.t,
                                            lambda: batch_greedy_action[i])
                for i in range(len(batch_greedy_action))
            ]

        self.batch_last_obs = list(batch_obs)
        self.batch_last_action = list(batch_action)

        return batch_action

    def batch_observe_and_train(self, batch_obs, batch_reward, batch_done,
                                batch_reset):
        """Observe a batch of action consequences for training.

        Args:
            batch_obs (Sequence of ~object): Observations.
            batch_reward (Sequence of float): Rewards.
            batch_done (Sequence of boolean): Boolean values where True
                indicates the current state is terminal.
            batch_reset (Sequence of boolean): Boolean values where True
                indicates the current episode will be reset, even if the
                current state is not terminal.

        Returns:
            None
        """
        for i in range(len(batch_obs)):
            self.t += 1
            # Update the target network
            if self.t % self.target_update_interval == 0:
                self.sync_target_network()
            if self.batch_last_obs[i] is not None:
                assert self.batch_last_action[i] is not None
                # Add a transition to the replay buffer
                self.replay_buffer.append(
                    state=self.batch_last_obs[i],
                    action=self.batch_last_action[i],
                    reward=batch_reward[i],
                    next_state=batch_obs[i],
                    next_action=None,
                    is_state_terminal=batch_done[i],
                    env_id=i,
                )
                if batch_reset[i] or batch_done[i]:
                    self.batch_last_obs[i] = None
                    self.replay_buffer.stop_current_episode(env_id=i)
            self.replay_updater.update_if_necessary(self.t)

    def batch_observe(self, batch_obs, batch_reward, batch_done, batch_reset):
        pass

    def stop_episode_and_train(self, state, reward, done=False):

        assert self.last_state is not None
        assert self.last_action is not None

        # Add a transition to the replay buffer
        self.replay_buffer.append(state=self.last_state,
                                  action=self.last_action,
                                  reward=reward,
                                  next_state=state,
                                  next_action=self.last_action,
                                  is_state_terminal=done)
        # Add to Normalizer
        if self.obs_normalizer:
            self.obs_normalizer(self.batch_states([state], self.xp, self.phi))
        self.stop_episode()

    def stop_episode(self):
        self.last_state = None
        self.last_action = None
        if isinstance(self.model, Recurrent):
            self.model.reset_state()
        self.replay_buffer.stop_current_episode()

    def get_statistics(self):
        return [
            ('average_q', self.average_q),
            ('average_actor_loss', self.average_actor_loss),
            ('average_critic_loss', self.average_critic_loss),
        ]
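The `sync_target_network` methods in these examples delegate to chainerrl's `synchronize_parameters`; in 'soft' mode that is a Polyak update, target <- tau * source + (1 - tau) * target, and in 'hard' mode a plain copy. Here is a small self-contained sketch of both update rules on plain arrays; the dict-of-arrays representation is only for illustration, not the library's parameter handling.

import numpy as np


def soft_update(src_params, dst_params, tau=1e-2):
    """Polyak-average source parameters into the target parameters in place."""
    for name, p in src_params.items():
        dst_params[name] = tau * p + (1.0 - tau) * dst_params[name]


def hard_update(src_params, dst_params):
    """Copy source parameters into the target parameters."""
    for name, p in src_params.items():
        dst_params[name] = p.copy()


# One soft step moves the target 1% of the way toward the online parameters
online = {'W': np.ones((2, 2))}
target = {'W': np.zeros((2, 2))}
soft_update(online, target, tau=0.01)  # target['W'] is now 0.01 everywhere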
Example #7
class SoftActorCritic(AttributeSavingMixin, BatchAgent):
    """Soft Actor-Critic (SAC).

    See https://arxiv.org/abs/1812.05905

    Args:
        policy (Policy): Policy.
        q_func1 (Link): First Q-function that takes state-action pairs as input
            and outputs predicted Q-values.
        q_func2 (Link): Second Q-function that takes state-action pairs as
            input and outputs predicted Q-values.
        policy_optimizer (Optimizer): Optimizer setup with the policy
        q_func1_optimizer (Optimizer): Optimizer setup with the first
            Q-function.
        q_func2_optimizer (Optimizer): Optimizer setup with the second
            Q-function.
        replay_buffer (ReplayBuffer): Replay buffer
        gamma (float): Discount factor
        gpu (int): GPU device id if not None nor negative.
        replay_start_size (int): if the replay buffer's size is less than
            replay_start_size, skip update
        minibatch_size (int): Minibatch size
        update_interval (int): Model update interval in step
        phi (callable): Feature extractor applied to observations
        soft_update_tau (float): Tau of soft target update.
        logger (Logger): Logger used
        batch_states (callable): method which makes a batch of observations.
            default is `chainerrl.misc.batch_states.batch_states`
        burnin_action_func (callable or None): If not None, this callable
            object is used to select actions before the model is updated
            one or more times during training.
        initial_temperature (float): Initial temperature value. If
            `entropy_target` is None, the temperature is fixed at this value.
        entropy_target (float or None): If set to a float, the temperature is
            adjusted during training to match the policy's entropy to it.
        temperature_optimizer (Optimizer or None): Optimizer used to optimize
            the temperature. If set to None, Adam with default hyperparameters
            is used.
        act_deterministically (bool): If set to True, choose most probable
            actions in the act method instead of sampling from distributions.
    """

    saved_attributes = (
        'policy',
        'q_func1',
        'q_func2',
        'target_q_func1',
        'target_q_func2',
        'policy_optimizer',
        'q_func1_optimizer',
        'q_func2_optimizer',
        'temperature_holder',
        'temperature_optimizer',
    )

    def __init__(
        self,
        policy,
        q_func1,
        q_func2,
        policy_optimizer,
        q_func1_optimizer,
        q_func2_optimizer,
        replay_buffer,
        gamma,
        gpu=None,
        replay_start_size=10000,
        minibatch_size=100,
        update_interval=1,
        phi=lambda x: x,
        soft_update_tau=5e-3,
        logger=getLogger(__name__),
        batch_states=batch_states,
        burnin_action_func=None,
        initial_temperature=1.,
        entropy_target=None,
        temperature_optimizer=None,
        act_deterministically=True,
    ):

        self.policy = policy
        self.q_func1 = q_func1
        self.q_func2 = q_func2

        if gpu is not None and gpu >= 0:
            cuda.get_device_from_id(gpu).use()
            self.policy.to_gpu(device=gpu)
            self.q_func1.to_gpu(device=gpu)
            self.q_func2.to_gpu(device=gpu)

        self.xp = self.policy.xp
        self.replay_buffer = replay_buffer
        self.gamma = gamma
        self.gpu = gpu
        self.phi = phi
        self.soft_update_tau = soft_update_tau
        self.logger = logger
        self.policy_optimizer = policy_optimizer
        self.q_func1_optimizer = q_func1_optimizer
        self.q_func2_optimizer = q_func2_optimizer
        self.replay_updater = ReplayUpdater(
            replay_buffer=replay_buffer,
            update_func=self.update,
            batchsize=minibatch_size,
            n_times_update=1,
            replay_start_size=replay_start_size,
            update_interval=update_interval,
            episodic_update=False,
        )
        self.batch_states = batch_states
        self.burnin_action_func = burnin_action_func
        self.initial_temperature = initial_temperature
        self.entropy_target = entropy_target
        if self.entropy_target is not None:
            self.temperature_holder = TemperatureHolder(
                initial_log_temperature=np.log(initial_temperature))
            if temperature_optimizer is not None:
                self.temperature_optimizer = temperature_optimizer
            else:
                self.temperature_optimizer = chainer.optimizers.Adam()
            self.temperature_optimizer.setup(self.temperature_holder)
            if gpu is not None and gpu >= 0:
                self.temperature_holder.to_gpu(device=gpu)
        else:
            self.temperature_holder = None
            self.temperature_optimizer = None
        self.act_deterministically = act_deterministically

        self.t = 0
        self.last_state = None
        self.last_action = None

        # Target model
        self.target_q_func1 = copy.deepcopy(self.q_func1)
        self.target_q_func2 = copy.deepcopy(self.q_func2)

        # Statistics
        self.q1_record = collections.deque(maxlen=1000)
        self.q2_record = collections.deque(maxlen=1000)
        self.entropy_record = collections.deque(maxlen=1000)
        self.q_func1_loss_record = collections.deque(maxlen=100)
        self.q_func2_loss_record = collections.deque(maxlen=100)

    @property
    def temperature(self):
        if self.entropy_target is None:
            return self.initial_temperature
        else:
            with chainer.no_backprop_mode():
                return float(self.temperature_holder().array)

    def sync_target_network(self):
        """Synchronize target network with current network."""
        synchronize_parameters(
            src=self.q_func1,
            dst=self.target_q_func1,
            method='soft',
            tau=self.soft_update_tau,
        )
        synchronize_parameters(
            src=self.q_func2,
            dst=self.target_q_func2,
            method='soft',
            tau=self.soft_update_tau,
        )

    def update_q_func(self, batch):
        """Compute loss for a given Q-function."""

        batch_next_state = batch['next_state']
        batch_rewards = batch['reward']
        batch_terminal = batch['is_state_terminal']
        batch_state = batch['state']
        batch_actions = batch['action']
        batch_discount = batch['discount']

        with chainer.no_backprop_mode(), chainer.using_config('train', False):
            next_action_distrib = self.policy(batch_next_state)
            next_actions, next_log_prob =\
                next_action_distrib.sample_with_log_prob()
            next_q1 = self.target_q_func1(batch_next_state, next_actions)
            next_q2 = self.target_q_func2(batch_next_state, next_actions)
            next_q = F.minimum(next_q1, next_q2)
            entropy_term = self.temperature * next_log_prob[..., None]
            assert next_q.shape == entropy_term.shape

            target_q = batch_rewards + batch_discount * \
                (1.0 - batch_terminal) * F.flatten(next_q - entropy_term)

        predict_q1 = F.flatten(self.q_func1(batch_state, batch_actions))
        predict_q2 = F.flatten(self.q_func2(batch_state, batch_actions))

        loss1 = 0.5 * F.mean_squared_error(target_q, predict_q1)
        loss2 = 0.5 * F.mean_squared_error(target_q, predict_q2)

        # Update stats
        self.q1_record.extend(cuda.to_cpu(predict_q1.array))
        self.q2_record.extend(cuda.to_cpu(predict_q2.array))
        self.q_func1_loss_record.append(float(loss1.array))
        self.q_func2_loss_record.append(float(loss2.array))

        self.q_func1_optimizer.update(lambda: loss1)
        self.q_func2_optimizer.update(lambda: loss2)

    def update_temperature(self, log_prob):
        assert not isinstance(log_prob, chainer.Variable)
        loss = -F.mean(
            F.broadcast_to(self.temperature_holder(), log_prob.shape) *
            (log_prob + self.entropy_target))
        self.temperature_optimizer.update(lambda: loss)

    def update_policy_and_temperature(self, batch):
        """Compute loss for actor."""

        batch_state = batch['state']

        action_distrib = self.policy(batch_state)
        actions, log_prob = action_distrib.sample_with_log_prob()
        q1 = self.q_func1(batch_state, actions)
        q2 = self.q_func2(batch_state, actions)
        q = F.minimum(q1, q2)

        entropy_term = self.temperature * log_prob[..., None]
        assert q.shape == entropy_term.shape
        loss = F.mean(entropy_term - q)

        self.policy_optimizer.update(lambda: loss)

        if self.entropy_target is not None:
            self.update_temperature(log_prob.array)

        # Record entropy
        with chainer.no_backprop_mode():
            try:
                self.entropy_record.extend(
                    cuda.to_cpu(action_distrib.entropy.array))
            except NotImplementedError:
                # Record - log p(x) instead
                self.entropy_record.extend(cuda.to_cpu(-log_prob.array))

    def update(self, experiences, errors_out=None):
        """Update the model from experiences"""

        batch = batch_experiences(experiences, self.xp, self.phi, self.gamma)
        self.update_q_func(batch)
        self.update_policy_and_temperature(batch)
        self.sync_target_network()

    def batch_select_greedy_action(self, batch_obs, deterministic=False):
        with chainer.using_config('train', False), chainer.no_backprop_mode():
            batch_xs = self.batch_states(batch_obs, self.xp, self.phi)
            if deterministic:
                batch_action = self.policy(batch_xs).most_probable.array
            else:
                batch_action = self.policy(batch_xs).sample().array
        return list(cuda.to_cpu(batch_action))

    def select_greedy_action(self, obs, deterministic=False):
        return self.batch_select_greedy_action([obs],
                                               deterministic=deterministic)[0]

    def act_and_train(self, obs, reward):

        self.logger.debug('t:%s r:%s', self.t, reward)

        if (self.burnin_action_func is not None
                and self.policy_optimizer.t == 0):
            action = self.burnin_action_func()
        else:
            action = self.select_greedy_action(obs)
        self.t += 1

        if self.last_state is not None:
            assert self.last_action is not None
            # Add a transition to the replay buffer
            self.replay_buffer.append(state=self.last_state,
                                      action=self.last_action,
                                      reward=reward,
                                      next_state=obs,
                                      next_action=action,
                                      is_state_terminal=False)

        self.last_state = obs
        self.last_action = action

        self.replay_updater.update_if_necessary(self.t)

        return self.last_action

    def act(self, obs):
        return self.select_greedy_action(
            obs, deterministic=self.act_deterministically)

    def batch_act(self, batch_obs):
        return self.batch_select_greedy_action(
            batch_obs, deterministic=self.act_deterministically)

    def batch_act_and_train(self, batch_obs):
        """Select a batch of actions for training.

        Args:
            batch_obs (Sequence of ~object): Observations.

        Returns:
            Sequence of ~object: Actions.
        """

        if (self.burnin_action_func is not None
                and self.policy_optimizer.t == 0):
            batch_action = [
                self.burnin_action_func() for _ in range(len(batch_obs))
            ]
        else:
            batch_action = self.batch_select_greedy_action(batch_obs)

        self.batch_last_obs = list(batch_obs)
        self.batch_last_action = list(batch_action)

        return batch_action

    def batch_observe_and_train(self, batch_obs, batch_reward, batch_done,
                                batch_reset):
        for i in range(len(batch_obs)):
            self.t += 1
            if self.batch_last_obs[i] is not None:
                assert self.batch_last_action[i] is not None
                # Add a transition to the replay buffer
                self.replay_buffer.append(
                    state=self.batch_last_obs[i],
                    action=self.batch_last_action[i],
                    reward=batch_reward[i],
                    next_state=batch_obs[i],
                    next_action=None,
                    is_state_terminal=batch_done[i],
                    env_id=i,
                )
                if batch_reset[i] or batch_done[i]:
                    self.batch_last_obs[i] = None
                    self.replay_buffer.stop_current_episode(env_id=i)
            self.replay_updater.update_if_necessary(self.t)

    def batch_observe(self, batch_obs, batch_reward, batch_done, batch_reset):
        pass

    def stop_episode_and_train(self, state, reward, done=False):

        assert self.last_state is not None
        assert self.last_action is not None

        # Add a transition to the replay buffer
        self.replay_buffer.append(state=self.last_state,
                                  action=self.last_action,
                                  reward=reward,
                                  next_state=state,
                                  next_action=self.last_action,
                                  is_state_terminal=done)

        self.stop_episode()

    def stop_episode(self):
        self.last_state = None
        self.last_action = None
        self.replay_buffer.stop_current_episode()

    def get_statistics(self):
        return [
            ('average_q1', _mean_or_nan(self.q1_record)),
            ('average_q2', _mean_or_nan(self.q2_record)),
            ('average_q_func1_loss', _mean_or_nan(self.q_func1_loss_record)),
            ('average_q_func2_loss', _mean_or_nan(self.q_func2_loss_record)),
            ('n_updates', self.policy_optimizer.t),
            ('average_entropy', _mean_or_nan(self.entropy_record)),
            ('temperature', self.temperature),
        ]
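Two pieces of the SAC update above are worth spelling out: `update_q_func` regresses onto the entropy-regularized target r + gamma * (1 - done) * (min_i Q_i(s', a') - alpha * log pi(a'|s')), and `update_temperature` adjusts alpha by minimizing -E[alpha * (log pi + entropy_target)] with log pi treated as a constant. The sketch below restates both quantities in plain NumPy; the inputs are stand-in arrays, not the library objects.

import numpy as np


def sac_target(reward, done, discount, next_q1, next_q2, next_log_prob,
               temperature):
    """Entropy-regularized Bellman target for a batch of transitions."""
    next_q = np.minimum(next_q1, next_q2) - temperature * next_log_prob
    return reward + discount * (1.0 - done) * next_q


def temperature_loss(log_temperature, log_prob, entropy_target):
    """Loss minimized w.r.t. log_temperature; log_prob is held constant."""
    alpha = np.exp(log_temperature)
    return -np.mean(alpha * (log_prob + entropy_target))


# Toy usage
target = sac_target(reward=np.array([1.0]), done=np.array([0.0]),
                    discount=np.array([0.99]),
                    next_q1=np.array([5.0]), next_q2=np.array([4.5]),
                    next_log_prob=np.array([-1.2]), temperature=0.2)
loss = temperature_loss(np.log(0.2), log_prob=np.array([-1.2]),
                        entropy_target=-3.0)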
Example #8
class EVA(agent.AttributeSavingMixin, agent.BatchAgent):
    """Ephemeral Value Adjustment Algorithm.
    Args:
        q_function (StateQFunction): Q-function
        optimizer (Optimizer): Optimizer that is already setup
        replay_buffer (ReplayBuffer): Replay buffer
        gamma (float): Discount factor
        explorer (Explorer): Explorer that specifies an exploration strategy.
        gpu (int): GPU device id if not None nor negative.
        replay_start_size (int): if the replay buffer's size is less than
            replay_start_size, skip update
        minibatch_size (int): Minibatch size
        update_interval (int): Model update interval in step
        target_update_interval (int): Target model update interval in step
        clip_delta (bool): Clip delta if set True
        phi (callable): Feature extractor applied to observations
        target_update_method (str): 'hard' or 'soft'.
        soft_update_tau (float): Tau of soft target update.
        n_times_update (int): Number of repetition of update
        average_q_decay (float): Decay rate of average Q, only used for
            recording statistics
        average_loss_decay (float): Decay rate of average loss, only used for
            recording statistics
        batch_accumulator (str): 'mean' or 'sum'
        episodic_update_len (int or None): Subsequences of this length are used
            for update if set int and recurrent=True
        logger (Logger): Logger used
        batch_states (callable): method which makes a batch of observations.
            default is `chainerrl.misc.batch_states.batch_states`
        recurrent (bool): If set to True, `model` is assumed to implement
            `chainerrl.links.StatelessRecurrent` and is updated in a recurrent
            manner.
    """

    saved_attributes = ('model', 'target_model', 'optimizer')

    def __init__(self,
                 q_function,
                 optimizer,
                 replay_buffer,
                 gamma,
                 explorer,
                 gpu=None,
                 replay_start_size=50000,
                 minibatch_size=32,
                 update_interval=1,
                 target_update_interval=10000,
                 clip_delta=True,
                 phi=lambda x: x,
                 target_update_method='hard',
                 soft_update_tau=1e-2,
                 n_times_update=1,
                 average_q_decay=0.999,
                 average_loss_decay=0.99,
                 batch_accumulator='mean',
                 episodic_update_len=None,
                 logger=getLogger(__name__),
                 batch_states=batch_states,
                 recurrent=False,
                 len_trajectory=50,
                 periodic_steps=20):
        self.model = q_function
        self.q_function = q_function  # For backward compatibility

        if gpu is not None and gpu >= 0:
            cuda.get_device_from_id(gpu).use()
            self.model.to_gpu(device=gpu)

        self.xp = self.model.xp
        self.replay_buffer = replay_buffer
        self.optimizer = optimizer
        self.gamma = gamma
        self.explorer = explorer
        self.gpu = gpu
        self.target_update_interval = target_update_interval
        self.clip_delta = clip_delta
        self.phi = phi
        self.target_update_method = target_update_method
        self.soft_update_tau = soft_update_tau
        self.batch_accumulator = batch_accumulator
        assert batch_accumulator in ('mean', 'sum')
        self.logger = logger
        self.batch_states = batch_states
        self.recurrent = recurrent
        if self.recurrent:
            update_func = self.update_from_episodes
        else:
            update_func = self.update
        self.replay_updater = ReplayUpdater(
            replay_buffer=replay_buffer,
            update_func=update_func,
            batchsize=minibatch_size,
            episodic_update=recurrent,
            episodic_update_len=episodic_update_len,
            n_times_update=n_times_update,
            replay_start_size=replay_start_size,
            update_interval=update_interval,
        )

        self.t = 0
        self.last_state = None
        self.last_action = None
        self.target_model = None
        self.sync_target_network()
        # For backward compatibility
        self.target_q_function = self.target_model
        self.average_q = 0
        self.average_q_decay = average_q_decay
        self.average_loss = 0
        self.average_loss_decay = average_loss_decay

        # Recurrent states of the model
        self.train_recurrent_states = None
        self.train_prev_recurrent_states = None
        self.test_recurrent_states = None

        # Error checking
        if (self.replay_buffer.capacity is not None
                and self.replay_buffer.capacity <
                self.replay_updater.replay_start_size):
            raise ValueError('Replay start size cannot exceed '
                             'replay buffer capacity.')

        self.last_embed = None
        self.len_trajectory = len_trajectory
        self.num_actions = self.model.num_actions
        self.periodic_steps = periodic_steps
        self.value_buffer = self.model.non_q
        self.current_t = 0

    def sync_target_network(self):
        """Synchronize target network with current network."""
        if self.target_model is None:
            self.target_model = copy.deepcopy(self.model.q_func)
            call_orig = self.target_model.__call__

            def call_test(self_, x):
                with chainer.using_config('train', False):
                    return call_orig(self_, x)

            self.target_model.__call__ = call_test
        else:
            synchronize_parameters(src=self.model.q_func,
                                   dst=self.target_model,
                                   method=self.target_update_method,
                                   tau=self.soft_update_tau)

    def update(self, experiences, errors_out=None):
        """Update the model from experiences
        Args:
            experiences (list): List of lists of dicts.
                For DQN, each dict must contains:
                  - state (object): State
                  - action (object): Action
                  - reward (float): Reward
                  - is_state_terminal (bool): True iff next state is terminal
                  - next_state (object): Next state
                  - weight (float, optional): Weight coefficient. It can be
                    used for importance sampling.
            errors_out (list or None): If set to a list, then TD-errors
                computed from the given experiences are appended to the list.
        Returns:
            None
        """
        has_weight = 'weight' in experiences[0][0]
        exp_batch = batch_experiences(experiences,
                                      xp=self.xp,
                                      phi=self.phi,
                                      gamma=self.gamma,
                                      batch_states=self.batch_states)
        if has_weight:
            exp_batch['weights'] = self.xp.asarray(
                [elem[0]['weight'] for elem in experiences],
                dtype=self.xp.float32)
            if errors_out is None:
                errors_out = []
        loss = self._compute_loss(exp_batch, errors_out=errors_out)
        if has_weight:
            self.replay_buffer.update_errors(errors_out)

        # Update stats
        self.average_loss *= self.average_loss_decay
        self.average_loss += (1 - self.average_loss_decay) * float(loss.array)

        self.model.cleargrads()
        loss.backward()
        self.optimizer.update()

    def update_from_episodes(self, episodes, errors_out=None):
        assert errors_out is None,\
            "Recurrent DQN does not support PrioritizedBuffer"
        exp_batch = batch_recurrent_experiences(
            episodes,
            model=self.model,
            xp=self.xp,
            phi=self.phi,
            gamma=self.gamma,
            batch_states=self.batch_states,
        )
        loss = self._compute_loss(exp_batch, errors_out=None)
        # Update stats
        self.average_loss *= self.average_loss_decay
        self.average_loss += (1 - self.average_loss_decay) * float(loss.array)
        self.optimizer.update(lambda: loss)

    def _compute_target_values(self, exp_batch):
        batch_next_state = exp_batch['next_state']

        if self.recurrent:
            target_next_qout, _ = self.target_model.n_step_forward(
                batch_next_state,
                exp_batch['next_recurrent_state'],
                output_mode='concat')
        else:
            target_next_qout = self.target_model(batch_next_state)
        next_q_max = target_next_qout.max

        batch_rewards = exp_batch['reward']
        batch_terminal = exp_batch['is_state_terminal']
        discount = exp_batch['discount']

        return batch_rewards + discount * (1.0 - batch_terminal) * next_q_max

    def _compute_y_and_t(self, exp_batch):
        batch_size = exp_batch['reward'].shape[0]

        # Compute Q-values for current states
        batch_state = exp_batch['state']

        if self.recurrent:
            qout, _ = self.model.n_step_forward(
                batch_state,
                exp_batch['recurrent_state'],
                output_mode='concat',
            )
        else:
            qout = self.model(batch_state)

        batch_actions = exp_batch['action']
        batch_q = F.reshape(qout.evaluate_actions(batch_actions),
                            (batch_size, 1))

        with chainer.no_backprop_mode():
            batch_q_target = F.reshape(self._compute_target_values(exp_batch),
                                       (batch_size, 1))

        return batch_q, batch_q_target

    def _compute_loss(self, exp_batch, errors_out=None):
        """Compute the Q-learning loss for a batch of experiences
        Args:
          exp_batch (dict): A dict of batched arrays of transitions
        Returns:
          Computed loss from the minibatch of experiences
        """
        y, t = self._compute_y_and_t(exp_batch)

        if errors_out is not None:
            del errors_out[:]
            delta = F.absolute(y - t)
            if delta.ndim == 2:
                delta = F.sum(delta, axis=1)
            delta = cuda.to_cpu(delta.array)
            for e in delta:
                errors_out.append(e)

        if 'weights' in exp_batch:
            return compute_weighted_value_loss(
                y,
                t,
                exp_batch['weights'],
                clip_delta=self.clip_delta,
                batch_accumulator=self.batch_accumulator)
        else:
            return compute_value_loss(y,
                                      t,
                                      clip_delta=self.clip_delta,
                                      batch_accumulator=self.batch_accumulator)

    def act(self, obs):
        with chainer.using_config('train', False), chainer.no_backprop_mode():
            action_value =\
                self._evaluate_model_and_update_recurrent_states(
                    [obs], test=True)
            q = float(action_value.max.array)
            action = cuda.to_cpu(action_value.greedy_actions.array)[0]
            embed = cuda.to_cpu(self.model.get_embedding().array)

        # Update stats
        self.average_q *= self.average_q_decay
        self.average_q += (1 - self.average_q_decay) * q

        self.logger.debug('t:%s q:%s action_value:%s', self.t, q, action_value)

        self.backup_store_if_necessary(embed, self.current_t)
        self.current_t += 1
        return action
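
    # act() is the evaluation-time counterpart of act_and_train(): the model
    # runs with train=False, the greedy action is always taken (no explorer),
    # and the observation embedding is still tracked so that the periodic
    # value-buffer backup keeps running during evaluation.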

    def act_and_train(self, obs, reward):

        # Observe the consequences
        if self.last_state is not None:
            assert self.last_action is not None
            # Add a transition to the replay buffer
            transition = {
                'state': self.last_state,
                'action': self.last_action,
                'reward': reward,
                'embedding': self.last_embed[0],
                'next_state': obs,
                'is_state_terminal': False,
            }
            if self.recurrent:
                transition['recurrent_state'] =\
                    self.model.get_recurrent_state_at(
                        self.train_prev_recurrent_states,
                        0, unwrap_variable=True)
                self.train_prev_recurrent_states = None
                transition['next_recurrent_state'] =\
                    self.model.get_recurrent_state_at(
                        self.train_recurrent_states, 0, unwrap_variable=True)
            self.replay_buffer.append(**transition)

        # Update the target network
        if self.t % self.target_update_interval == 0:
            self.sync_target_network()

        # Update the model
        self.replay_updater.update_if_necessary(self.t)

        # Choose an action
        with chainer.using_config('train', False), chainer.no_backprop_mode():
            action_value =\
                self._evaluate_model_and_update_recurrent_states(
                    [obs], test=False)
            q = float(action_value.max.array)
            greedy_action = cuda.to_cpu(action_value.greedy_actions.array)[0]
            embed = cuda.to_cpu(self.model.get_embedding().array)
        action = self.explorer.select_action(self.t,
                                             lambda: greedy_action,
                                             action_value=action_value)

        # Update stats
        self.average_q *= self.average_q_decay
        self.average_q += (1 - self.average_q_decay) * q

        self.t += 1
        self.last_state = obs
        self.last_action = action
        self.last_embed = embed

        self.logger.debug('t:%s q:%s action_value:%s', self.t, q, action_value)
        self.logger.debug('t:%s r:%s a:%s', self.t, reward, action)

        self.backup_store_if_necessary(self.last_embed, self.t)

        return self.last_action
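
    # A minimal sketch of how act_and_train/stop_episode_and_train above are
    # usually driven (assuming a Gym-style `env` and an `agent` instance of
    # this class; the loop is illustrative and not part of this file):
    #
    #     obs = env.reset()
    #     reward = 0.0
    #     done = False
    #     while not done:
    #         action = agent.act_and_train(obs, reward)
    #         obs, reward, done, _ = env.step(action)
    #     agent.stop_episode_and_train(obs, reward, done)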

    def _evaluate_model_and_update_recurrent_states(self, batch_obs, test):
        batch_xs = self.batch_states(batch_obs, self.xp, self.phi)
        if self.recurrent:
            if test:
                batch_av, self.test_recurrent_states = self.model(
                    batch_xs, self.test_recurrent_states)
            else:
                self.train_prev_recurrent_states = self.train_recurrent_states
                batch_av, self.train_recurrent_states = self.model(
                    batch_xs, self.train_recurrent_states)
        else:
            batch_av = self.model(
                batch_xs,
                eva=(len(self.value_buffer.embeddings)
                     == self.value_buffer.capacity))
        return batch_av
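
    # Note on the non-recurrent branch above: the `eva` flag passed to the
    # model is True only once the value buffer has been filled to capacity;
    # until then the model presumably falls back to its purely parametric
    # Q-values. How the flag changes the forward pass is the model's concern
    # and is not visible in this class.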

    def batch_act_and_train(self, batch_obs):
        with chainer.using_config('train', False), chainer.no_backprop_mode():
            batch_av = self._evaluate_model_and_update_recurrent_states(
                batch_obs, test=False)
            batch_maxq = batch_av.max.array
            batch_argmax = cuda.to_cpu(batch_av.greedy_actions.array)
        batch_action = [
            self.explorer.select_action(
                self.t,
                lambda: batch_argmax[i],
                action_value=batch_av[i:i + 1],
            ) for i in range(len(batch_obs))
        ]
        self.batch_last_obs = list(batch_obs)
        self.batch_last_action = list(batch_action)

        # Update stats
        self.average_q *= self.average_q_decay
        self.average_q += (1 - self.average_q_decay) * float(batch_maxq.mean())

        return batch_action

    def batch_act(self, batch_obs):
        with chainer.using_config('train', False), chainer.no_backprop_mode():
            batch_av = self._evaluate_model_and_update_recurrent_states(
                batch_obs, test=True)
            batch_argmax = cuda.to_cpu(batch_av.greedy_actions.array)
            return batch_argmax

    def batch_observe_and_train(self, batch_obs, batch_reward, batch_done,
                                batch_reset):
        for i in range(len(batch_obs)):
            self.t += 1
            # Update the target network
            if self.t % self.target_update_interval == 0:
                self.sync_target_network()
            if self.batch_last_obs[i] is not None:
                assert self.batch_last_action[i] is not None
                # Add a transition to the replay buffer
                transition = {
                    'state': self.batch_last_obs[i],
                    'action': self.batch_last_action[i],
                    'reward': batch_reward[i],
                    'next_state': batch_obs[i],
                    'next_action': None,
                    'is_state_terminal': batch_done[i],
                }
                if self.recurrent:
                    transition['recurrent_state'] =\
                        self.model.get_recurrent_state_at(
                            self.train_prev_recurrent_states,
                            i, unwrap_variable=True)
                    transition['next_recurrent_state'] =\
                        self.model.get_recurrent_state_at(
                            self.train_recurrent_states,
                            i, unwrap_variable=True)
                self.replay_buffer.append(env_id=i, **transition)
                if batch_reset[i] or batch_done[i]:
                    self.batch_last_obs[i] = None
                    self.batch_last_action[i] = None
                    self.replay_buffer.stop_current_episode(env_id=i)
            self.replay_updater.update_if_necessary(self.t)

        if self.recurrent:
            # Reset recurrent states when episodes end
            self.train_prev_recurrent_states = None
            self.train_recurrent_states =\
                _batch_reset_recurrent_states_when_episodes_end(
                    model=self.model,
                    batch_done=batch_done,
                    batch_reset=batch_reset,
                    recurrent_states=self.train_recurrent_states,
                )

    def batch_observe(self, batch_obs, batch_reward, batch_done, batch_reset):
        if self.recurrent:
            # Reset recurrent states when episodes end
            self.test_recurrent_states =\
                _batch_reset_recurrent_states_when_episodes_end(
                    model=self.model,
                    batch_done=batch_done,
                    batch_reset=batch_reset,
                    recurrent_states=self.test_recurrent_states,
                )

    def stop_episode_and_train(self, state, reward, done=False):
        """Observe a terminal state and a reward.

        This function must be called once when an episode terminates.
        """

        assert self.last_state is not None
        assert self.last_action is not None
        assert self.last_embed is not None

        # Add a transition to the replay buffer
        transition = {
            'state': self.last_state,
            'action': self.last_action,
            'reward': reward,
            'embedding': self.last_embed[0],
            'next_state': state,
            'next_action': self.last_action,
            'is_state_terminal': done,
        }
        if self.recurrent:
            transition['recurrent_state'] =\
                self.model.get_recurrent_state_at(
                    self.train_prev_recurrent_states, 0, unwrap_variable=True)
            self.train_prev_recurrent_states = None
            transition['next_recurrent_state'] =\
                self.model.get_recurrent_state_at(
                    self.train_recurrent_states, 0, unwrap_variable=True)
        self.replay_buffer.append(**transition)

        self.backup_store_if_necessary(self.last_embed, self.t)

        self.last_state = None
        self.last_action = None
        self.last_embed = None
        self.current_t = 0

        if self.recurrent:
            self.train_recurrent_states = None

        self.replay_buffer.stop_current_episode()

    def stop_episode(self):
        if self.recurrent:
            self.test_recurrent_states = None

    def get_statistics(self):
        return [
            ('average_q', self.average_q),
            ('average_loss', self.average_loss),
            ('n_updates', self.optimizer.t),
        ]

    def backup_store_if_necessary(self, embedding, t):
        # The non-parametric backup is only maintained when the parametric and
        # non-parametric estimates are actually blended (presumably lambdas of
        # 0 or 1 means one of the two estimates is unused).
        if self.model.lambdas in (0, 1):
            return
        if (t % self.periodic_steps == 0
                and self.t >= self.replay_buffer.capacity):
            self.replay_buffer.update_embedding()
            trajectories = self.replay_buffer.lookup(embedding)
            batch_trajectory = [{
                'state': batch_states(
                    [elem[0]['state'] for elem in traject], self.xp, self.phi),
                'action': [elem[0]['action'] for elem in traject],
                'reward': [elem[0]['reward'] for elem in traject],
                'embedding': [elem[0]['embedding'] for elem in traject],
            } for traject in trajectories]

            qnp, embeddings = self._trajectory_centric_planning(
                batch_trajectory)
            self.value_buffer.store(embeddings, qnp)
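
    # In short, this is the periodic "backup and store" step: once the replay
    # buffer is full, the current embedding is used to look up nearby stored
    # trajectories, their non-parametric Q-values are computed by
    # _trajectory_centric_planning(), and the resulting (embedding, Q) pairs
    # are written into the value buffer. The lookup and storage details are
    # hedged here: they live in the replay/value buffer classes, which are
    # not shown.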

    def _trajectory_centric_planning(self, trajectories):
        # Assumes Atari-style observations of shape (4, 84, 84).
        embeddings = []
        batch_state = []
        for trajectory in trajectories:
            embeddings += trajectory['embedding']
            # Pad each trajectory to a fixed length so every state can be
            # evaluated in one forward pass; the padded rows are discarded
            # again below via qnp[:T].
            batch = self.xp.empty((self.len_trajectory, 4, 84, 84),
                                  dtype=self.xp.float32)
            batch[:len(trajectory['state'])] = trajectory['state']
            batch_state.append(batch)

        batch_state = self.xp.concatenate(batch_state,
                                          axis=0).astype(self.xp.float32)

        with chainer.using_config('train', False), chainer.no_backprop_mode():
            parametric_q = self.model(batch_state, eva=False)
            parametric_q = cuda.to_cpu(parametric_q.q_values.array).reshape(
                (len(trajectories), self.len_trajectory, self.num_actions))

        q_value = []
        for qnp, trajectory in zip(parametric_q, trajectories):
            action = trajectory['action']
            reward = trajectory['reward']
            T = len(action)
            qnp = qnp[:T]
            # Backward pass over the trajectory: bootstrap from the parametric
            # value of the final state, then back up the observed rewards
            # along the actions that were actually taken.
            Vnp = np.max(qnp[T - 1])
            for t in range(T - 2, -1, -1):
                qnp[t][action[t]] = reward[t] + self.gamma * Vnp
                Vnp = np.max(qnp[t])
            q_value.append(qnp)

        return self.xp.asarray(
            np.concatenate(q_value, axis=0).astype(
                np.float32)), self.xp.asarray(embeddings)
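
    # Worked example of the backward pass above (hedged, illustrative numbers
    # with two actions and gamma = 0.9):
    #
    #     actions = [0, 1, 0], rewards = [0.0, 1.0, 0.5], parametric
    #     qnp = [[0.5, 0.2],
    #            [0.3, 0.4],
    #            [0.1, 0.6]]
    #
    #     Vnp = max(qnp[2]) = 0.6                      (bootstrap, last step)
    #     t=1: qnp[1][1] = 1.0 + 0.9 * 0.6  = 1.54,    Vnp = 1.54
    #     t=0: qnp[0][0] = 0.0 + 0.9 * 1.54 = 1.386,   Vnp = 1.386
    #
    # Only the actions actually taken are overwritten; the other entries keep
    # their parametric estimates, and the final reward/action are unused
    # because the last step only supplies the parametric bootstrap value.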