Example 1
    def update(self, experiences, errors_out=None):
        """Update the model from experiences

        This function is thread-safe.
        Args:
          experiences (list): list of dicts, each of which contains
            state: cupy.ndarray or numpy.ndarray
            action: int [0, n_action_types)
            reward: float32
            next_state: cupy.ndarray or numpy.ndarray
            next_legal_actions: list of booleans; True means legal
          errors_out (list or None): if set to a list, TD-errors computed
            from the given experiences are appended to it
        Returns:
          None
        """

        exp_batch = batch_experiences(experiences,
                                      xp=self.xp,
                                      phi=self.phi,
                                      batch_states=self.batch_states)
        loss = self._compute_loss(exp_batch, self.gamma, errors_out=errors_out)

        # Update stats
        self.average_loss *= self.average_loss_decay
        self.average_loss += (1 - self.average_loss_decay) * float(loss.data)

        self.optimizer.zero_grads()
        loss.backward()
        self.optimizer.update()
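
For context, an update() like the one above is usually driven by sampling a minibatch from a replay buffer. A minimal sketch of such a driver loop, where agent, replay_buffer, replay_start_size and minibatch_size are illustrative names rather than part of the code above:

# Sketch only: all names below are illustrative.
if len(replay_buffer) >= replay_start_size:
    experiences = replay_buffer.sample(minibatch_size)
    agent.update(experiences)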
Example 2
    def update(self, experiences, errors_out=None):
        """Update the model from experiences"""

        batch = batch_experiences(experiences, self.xp, self.phi, self.gamma)
        self.update_q_func(batch)
        self.update_policy_and_temperature(batch)
        self.sync_target_network()
Example 3
 def test_batch_experiences(self):
     experiences = []
     experiences.append(
         [dict(state=1, action=1, reward=1, next_state=i,
               next_action=1, is_state_terminal=False) for i in range(3)])
     experiences.append([dict(state=1, action=1, reward=1, next_state=1,
                              next_action=1, is_state_terminal=False)])
     four_step_transition = [dict(
         state=1, action=1, reward=1, next_state=1,
         next_action=1, is_state_terminal=False)] * 3
     four_step_transition.append(dict(
                                 state=1, action=1, reward=1, next_state=5,
                                 next_action=1, is_state_terminal=True))
     experiences.append(four_step_transition)
     batch = replay_buffer.batch_experiences(
         experiences, np, lambda x: x, 0.99)
     self.assertEqual(batch['state'][0], 1)
     self.assertSequenceEqual(list(batch['is_state_terminal']),
                              list(np.asarray([0.0, 0.0, 1.0],
                                              dtype=np.float32)))
     self.assertSequenceEqual(list(batch['discount']),
                              list(np.asarray([
                                   0.99 ** 3, 0.99 ** 1, 0.99 ** 4],
                                  dtype=np.float32)))
     self.assertSequenceEqual(list(batch['next_state']),
                              list(np.asarray([2, 1, 5])))
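
For reference, the n-step folding this test exercises can be sketched as below; fold_nstep is an illustrative stand-in with a hypothetical name, not the ChainerRL implementation of batch_experiences. The state and action come from the first step of each chain, next_state and is_state_terminal from the last, and discount is gamma raised to the chain length, which is why the test expects 0.99 ** 3, 0.99 ** 1 and 0.99 ** 4.

import numpy as np

def fold_nstep(experiences, xp=np, phi=lambda x: x, gamma=0.99):
    # Illustrative stand-in: each element of `experiences` is a chain of
    # consecutive transitions that gets folded into one n-step transition.
    batch = {'state': [], 'action': [], 'reward': [], 'next_state': [],
             'is_state_terminal': [], 'discount': []}
    for chain in experiences:
        batch['state'].append(phi(chain[0]['state']))
        batch['action'].append(chain[0]['action'])
        # discounted sum of the rewards along the chain
        batch['reward'].append(
            sum(gamma ** k * step['reward'] for k, step in enumerate(chain)))
        batch['next_state'].append(phi(chain[-1]['next_state']))
        batch['is_state_terminal'].append(float(chain[-1]['is_state_terminal']))
        batch['discount'].append(gamma ** len(chain))
    return {key: xp.asarray(values) for key, values in batch.items()}

# For the three chains built in the test above:
# fold_nstep(experiences)['discount']   -> [0.99 ** 3, 0.99 ** 1, 0.99 ** 4]
# fold_nstep(experiences)['next_state'] -> [2, 1, 5]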
Example 4
    def update_from_episodes(self, episodes, errors_out=None):
        has_weights = isinstance(episodes, tuple)
        if has_weights:
            episodes, weights = episodes
            if errors_out is None:
                errors_out = []
        if errors_out is None:
            errors_out_step = None
        else:
            del errors_out[:]
            for _ in episodes:
                errors_out.append(0.0)
            errors_out_step = []

        with state_reset(self.model), state_reset(self.target_model):
            loss = 0
            tmp = list(reversed(sorted(
                enumerate(episodes), key=lambda x: len(x[1]))))
            sorted_episodes = [elem[1] for elem in tmp]
            indices = [elem[0] for elem in tmp]  # argsort
            max_epi_len = len(sorted_episodes[0])
            for i in range(max_epi_len):
                transitions = []
                weights_step = []
                for ep, index in zip(sorted_episodes, indices):
                    if len(ep) <= i:
                        break
                    transitions.append([ep[i]])
                    if has_weights:
                        weights_step.append(weights[index])
                batch = batch_experiences(
                    transitions,
                    xp=self.xp,
                    phi=self.phi,
                    gamma=self.gamma,
                    batch_states=self.batch_states)
                assert len(batch['state']) == len(transitions)
                if i == 0:
                    self.input_initial_batch_to_target_model(batch)
                if has_weights:
                    batch['weights'] = self.xp.asarray(
                        weights_step, dtype=self.xp.float32)
                loss += self._compute_loss(batch,
                                           errors_out=errors_out_step)
                if errors_out is not None:
                    for err, index in zip(errors_out_step, indices):
                        errors_out[index] += err
            loss /= max_epi_len

            # Update stats
            self.average_loss *= self.average_loss_decay
            self.average_loss += \
                (1 - self.average_loss_decay) * float(loss.array)

            self.model.cleargrads()
            loss.backward()
            self.optimizer.update()
        if has_weights:
            self.replay_buffer.update_errors(errors_out)
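
The per-time-step batching used above (and in the recurrent variants further down) can be summarised by a small illustrative helper; timestep_slices is a hypothetical name, not part of the agent. Sorting the episodes longest-first means that, as shorter episodes run out, each step's transition list is a prefix of the previous one, which is exactly what the break in the inner loop relies on.

def timestep_slices(episodes):
    # Illustrative helper: yield, for each time step i, the step-i
    # transition of every episode that is still running.
    ordered = sorted(episodes, key=len, reverse=True)
    for i in range(len(ordered[0])):
        yield [[ep[i]] for ep in ordered if len(ep) > i]

# for transitions in timestep_slices(episodes):
#     batch = batch_experiences(transitions, xp=self.xp, phi=self.phi,
#                               gamma=self.gamma, batch_states=self.batch_states)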
Example 5
    def update(self, experiences, errors_out=None):
        """Update the model from experiences"""

        batch = batch_experiences(experiences, self.xp, self.phi, self.gamma)
        self.update_q_func(batch)
        if self.q_func1_optimizer.t % self.policy_update_delay == 0:
            self.update_policy(batch)
            self.sync_target_network()
Example 6
 def update(self, experiences, errors_out=None):
     """Update the model from experiences"""
     batch = batch_experiences(experiences, self.xp, self.phi, self.gamma)
     if self.obs_normalizer:
         batch['state'] = self.obs_normalizer(batch['state'], update=False)
         batch['next_state'] = self.obs_normalizer(batch['next_state'],
                                                   update=False)
     self.critic_optimizer.update(lambda: self.compute_critic_loss(batch))
     self.actor_optimizer.update(lambda: self.compute_actor_loss(batch))
Example 7
 def extra_update(agent, all_experiences):
     for epoch in range(args.cpo_distill_epochs):
         n_samples = len(all_experiences)
         indices = np.asarray(range(n_samples))
         np.random.shuffle(indices)
         for start_idx in (range(0, n_samples,
                                 args.cpo_distill_batch_size)):
             batch_idx = indices[start_idx:start_idx +
                                 args.cpo_distill_batch_size].astype(
                                     np.int32)
             experiences = [all_experiences[idx] for idx in batch_idx]
             batch = batch_experiences(experiences, agent.xp, agent.phi,
                                       agent.gamma)
             agent.update_policy_and_temperature(batch)
Example 8
    def update_from_episodes(self, episodes, errors_out=None):
        # Sort episodes desc by their lengths
        sorted_episodes = list(reversed(sorted(episodes, key=len)))
        max_epi_len = len(sorted_episodes[0])

        # Precompute all the input batches
        batches = []
        for i in range(max_epi_len):
            transitions = []
            for ep in sorted_episodes:
                if len(ep) <= i:
                    break
                transitions.append([ep[i]])
            batch = batch_experiences(transitions,
                                      xp=self.xp,
                                      phi=self.phi,
                                      gamma=self.gamma)
            if self.obs_normalizer:
                batch['state'] = self.obs_normalizer(batch['state'],
                                                     update=False)
                batch['next_state'] = self.obs_normalizer(batch['next_state'],
                                                          update=False)
            batches.append(batch)

        with self.model.state_reset(), self.target_model.state_reset():

            # Since the target model is evaluated one-step ahead,
            # its internal states need to be updated
            self.target_q_function.update_state(batches[0]['state'],
                                                batches[0]['action'])
            self.target_policy(batches[0]['state'])

            # Update critic through time
            critic_loss = 0
            for batch in batches:
                critic_loss += self.compute_critic_loss(batch)
            self.critic_optimizer.update(lambda: critic_loss / max_epi_len)

        with self.model.state_reset():

            # Update actor through time
            actor_loss = 0
            for batch in batches:
                actor_loss += self.compute_actor_loss(batch)
            self.actor_optimizer.update(lambda: actor_loss / max_epi_len)
Example 9
    def update(self, experiences, errors_out=None):
        """Update the model from experiences

        Args:
            experiences (list): List of lists of dicts.
                For DQN, each dict must contain:
                  - state (object): State
                  - action (object): Action
                  - reward (float): Reward
                  - is_state_terminal (bool): True iff next state is terminal
                  - next_state (object): Next state
                  - weight (float, optional): Weight coefficient. It can be
                    used for importance sampling.
            errors_out (list or None): If set to a list, then TD-errors
                computed from the given experiences are appended to the list.

        Returns:
            None
        """
        has_weight = 'weight' in experiences[0][0]
        exp_batch = batch_experiences(experiences,
                                      xp=self.xp,
                                      phi=self.phi,
                                      gamma=self.gamma,
                                      batch_states=self.batch_states)
        if has_weight:
            exp_batch['weights'] = self.xp.asarray(
                [elem[0]['weight'] for elem in experiences],
                dtype=self.xp.float32)
            if errors_out is None:
                errors_out = []
        loss = self._compute_loss(exp_batch, errors_out=errors_out)
        if has_weight:
            self.replay_buffer.update_errors(errors_out)

        # Update stats
        self.average_loss *= self.average_loss_decay
        self.average_loss += (1 - self.average_loss_decay) * float(loss.array)

        self.model.cleargrads()
        loss.backward()
        self.optimizer.update()
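
When the sampled transitions carry a 'weight' key, as checked above, they typically come from a prioritized replay buffer, and the TD-errors collected in errors_out are fed back to it via replay_buffer.update_errors(). A hedged sketch of that interaction, with prioritized_buffer and minibatch_size as illustrative names:

# Sketch only: prioritized_buffer and minibatch_size are illustrative names.
experiences = prioritized_buffer.sample(minibatch_size)
# Each sampled transition includes a 'weight' entry, so update() computes a
# weighted loss and reports the TD-errors back to the buffer itself.
agent.update(experiences)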
Example 10
    def update(self, experiences, errors_out=None):
        """Update the model from experiences

        This function is thread-safe.
        Args:
          experiences (list): list of lists of dicts, each of which contains
            state: cupy.ndarray or numpy.ndarray
            action: int [0, n_action_types)
            reward: float32
            is_state_terminal: bool
            next_state: cupy.ndarray or numpy.ndarray
            weight (optional): float32
          errors_out (list or None): if set to a list, TD-errors computed
            from the given experiences are appended to it
        Returns:
          None
        """
        has_weight = 'weight' in experiences[0][0]
        exp_batch = batch_experiences(experiences,
                                      xp=self.xp,
                                      phi=self.phi,
                                      gamma=self.gamma,
                                      batch_states=self.batch_states)
        if has_weight:
            exp_batch['weights'] = self.xp.asarray(
                [elem[0]['weight'] for elem in experiences],
                dtype=self.xp.float32)
            if errors_out is None:
                errors_out = []
        loss = self._compute_loss(exp_batch, errors_out=errors_out)
        if has_weight:
            self.replay_buffer.update_errors(errors_out)

        # Update stats
        self.average_loss *= self.average_loss_decay
        self.average_loss += (1 - self.average_loss_decay) * float(loss.array)

        self.model.cleargrads()
        loss.backward()
        self.optimizer.update()
Example 11
    def update(self, experiences, errors_out=None):
        """Update the model from experiences

        This function is thread-safe.
        Args:
          experiences (list): list of dicts, each of which contains
            state: cupy.ndarray or numpy.ndarray
            action: int [0, n_action_types)
            reward: float32
            next_state: cupy.ndarray or numpy.ndarray
            next_legal_actions: list of booleans; True means legal
            weight (optional): float32
          errors_out (list or None): if set to a list, TD-errors computed
            from the given experiences are appended to it
        Returns:
          None
        """

        has_weight = 'weight' in experiences[0]
        exp_batch = batch_experiences(experiences,
                                      xp=self.xp,
                                      phi=self.phi,
                                      batch_states=self.batch_states)
        if has_weight:
            exp_batch['weights'] = self.xp.asarray(
                [elem['weight'] for elem in experiences],
                dtype=self.xp.float32)
            if errors_out is None:
                errors_out = []
        loss = self._compute_loss(exp_batch, self.gamma, errors_out=errors_out)
        if has_weight:
            self.replay_buffer.update_errors(errors_out)

        # Update stats
        self.average_loss *= self.average_loss_decay
        self.average_loss += (1 - self.average_loss_decay) * float(loss.data)

        self.model.cleargrads()
        loss.backward()
        self.optimizer.update()
Example 12
 def update_from_episodes(self, episodes, errors_out=None):
     with state_reset(self.model):
         with state_reset(self.target_model):
             loss = 0
             sorted_episodes = list(reversed(sorted(episodes, key=len)))
             max_epi_len = len(sorted_episodes[0])
             for i in range(max_epi_len):
                 transitions = []
                 for ep in sorted_episodes:
                     if len(ep) <= i:
                         break
                     transitions.append(ep[i])
                 batch = batch_experiences(transitions,
                                           xp=self.xp,
                                           phi=self.phi,
                                           batch_states=self.batch_states)
                 if i == 0:
                     self.input_initial_batch_to_target_model(batch)
                 loss += self._compute_loss(batch, self.gamma)
             loss /= max_epi_len
             self.optimizer.zero_grads()
             loss.backward()
             self.optimizer.update()
Example 13
    def update(self, experiences, errors_out=None):
        """Update the model from experiences"""

        batch = batch_experiences(experiences, self.xp, self.phi)
        self.critic_optimizer.update(lambda: self.compute_critic_loss(batch))
        self.actor_optimizer.update(lambda: self.compute_actor_loss(batch))
Example 14
    def update(self, experiences, errors_out=None):
        """Update the model from experiences."""

        batch_size = len(experiences)

        batch_exp = batch_experiences(
            experiences,
            xp=self.xp,
            phi=self.phi,
            gamma=self.gamma,
            batch_states=self.batch_states,
        )

        batch_state = batch_exp['state']
        batch_actions = batch_exp['action']
        batch_next_state = batch_exp['next_state']
        batch_rewards = batch_exp['reward']
        batch_terminal = batch_exp['is_state_terminal']
        batch_discount = batch_exp['discount']

        # Update Q-function
        def compute_critic_loss():

            with chainer.no_backprop_mode():
                pout = self.target_policy(batch_next_state)
                next_actions = pout.sample()
                next_q = self.target_q_function(batch_next_state, next_actions)
                assert next_q.shape == (batch_size, 1)

                target_q = (batch_rewards[..., None] +
                            (batch_discount[..., None] *
                             (1.0 - batch_terminal[..., None]) * next_q))
                assert target_q.shape == (batch_size, 1)

            predict_q = self.q_function(batch_state, batch_actions)
            assert predict_q.shape == (batch_size, 1)

            loss = F.mean_squared_error(target_q, predict_q)

            # Update stats
            self.average_critic_loss *= self.average_loss_decay
            self.average_critic_loss += ((1 - self.average_loss_decay) *
                                         float(loss.array))

            return loss

        def compute_actor_loss():
            pout = self.policy(batch_state)
            sampled_actions = pout.sample().array
            log_probs = pout.log_prob(sampled_actions)
            with chainer.using_config('train', False):
                q = self.q_function(batch_state, sampled_actions)
                v = self.q_function(batch_state, pout.most_probable)
            advantage = F.reshape(q - v, (batch_size, ))
            advantage = chainer.Variable(advantage.array)
            loss = - F.sum(advantage * log_probs + self.beta * pout.entropy) \
                / batch_size

            # Update stats
            self.average_actor_loss *= self.average_loss_decay
            self.average_actor_loss += ((1 - self.average_loss_decay) *
                                        float(loss.array))

            return loss

        self.critic_optimizer.update(compute_critic_loss)
        self.actor_optimizer.update(compute_actor_loss)
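
As a worked illustration of the critic target computed above (toy numbers in plain NumPy, independent of the agent's code): target_q = reward + discount * (1 - terminal) * next_q, broadcast to shape (batch_size, 1).

import numpy as np

rewards = np.asarray([1.0, 0.5], dtype=np.float32)
discount = np.asarray([0.99, 0.99 ** 3], dtype=np.float32)  # n-step discounts
terminal = np.asarray([0.0, 1.0], dtype=np.float32)
next_q = np.asarray([[2.0], [4.0]], dtype=np.float32)

target_q = (rewards[..., None] +
            discount[..., None] * (1.0 - terminal[..., None]) * next_q)
# -> [[1.0 + 0.99 * 2.0], [0.5 + 0.0]] == [[2.98], [0.5]], shape (2, 1)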
Example 15
    def update_from_replay(self):

        if self.replay_buffer is None:
            return

        if len(self.replay_buffer) < self.replay_start_size:
            return

        if self.replay_buffer.n_episodes < self.batchsize:
            return

        if self.process_idx == 0:
            self.logger.debug('update_from_replay')

        episodes = self.replay_buffer.sample_episodes(self.batchsize,
                                                      max_len=self.t_max)
        if isinstance(episodes, tuple):
            # Prioritized replay
            episodes, weights = episodes
        else:
            weights = [1] * len(episodes)
        sorted_episodes = list(reversed(sorted(episodes, key=len)))
        max_epi_len = len(sorted_episodes[0])

        with state_reset(self.model):
            # Batch computation of multiple episodes
            rewards = {}
            values = {}
            next_values = {}
            log_probs = {}
            next_action_distrib = None
            next_v = None
            for t in range(max_epi_len):
                transitions = []
                for ep in sorted_episodes:
                    if len(ep) <= t:
                        break
                    transitions.append([ep[t]])
                batch = batch_experiences(transitions,
                                          xp=self.xp,
                                          phi=self.phi,
                                          gamma=self.gamma,
                                          batch_states=self.batch_states)
                batchsize = batch['action'].shape[0]
                if next_action_distrib is not None:
                    action_distrib = next_action_distrib[0:batchsize]
                    v = next_v[0:batchsize]
                else:
                    action_distrib, v = self.model(batch['state'])
                next_action_distrib, next_v = self.model(batch['next_state'])
                values[t] = v
                next_values[t] = next_v * \
                    (1 - batch['is_state_terminal'].reshape(next_v.shape))
                rewards[t] = chainer.cuda.to_cpu(batch['reward'])
                log_probs[t] = action_distrib.log_prob(batch['action'])
            # Loss is computed one by one episode
            losses = []
            for i, ep in enumerate(sorted_episodes):
                e_values = {}
                e_next_values = {}
                e_rewards = {}
                e_log_probs = {}
                for t in range(len(ep)):
                    assert values[t].shape[0] > i
                    assert next_values[t].shape[0] > i
                    assert rewards[t].shape[0] > i
                    assert log_probs[t].shape[0] > i
                    e_values[t] = values[t][i:i + 1]
                    e_next_values[t] = next_values[t][i:i + 1]
                    e_rewards[t] = float(rewards[t][i:i + 1])
                    e_log_probs[t] = log_probs[t][i:i + 1]
                losses.append(
                    self.compute_loss(t_start=0,
                                      t_stop=len(ep),
                                      rewards=e_rewards,
                                      values=e_values,
                                      next_values=e_next_values,
                                      log_probs=e_log_probs))
            loss = chainerrl.functions.weighted_sum_arrays(
                losses, weights) / self.batchsize
            self.update(loss)
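
The final reduction above averages the per-episode losses with their prioritized-replay weights before a single gradient step; conceptually it amounts to the sketch below (hypothetical name, not the chainerrl.functions.weighted_sum_arrays implementation):

def weighted_mean_sketch(losses, weights, batchsize):
    # Conceptual equivalent of the reduction above (sketch only).
    return sum(w * l for l, w in zip(losses, weights)) / batchsize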