Example #1
 def update(self, experiences, errors_out=None):
     """Update the model from experiences"""
     batch = batch_experiences(experiences, self.device, self.phi,
                               self.gamma)
     self.update_q_func(batch)
     self.update_policy_and_temperature(batch)
     self.sync_target_network()
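Each of the update methods on this page receives experiences as a list of lists of transition dicts and hands it to batch_experiences to collate into tensors. Below is a minimal sketch of that call in isolation, assuming these snippets come from pfrl (where batch_experiences lives in pfrl.replay_buffer, as the test in Example #4 suggests) and using toy scalar states:

import torch
from pfrl.replay_buffer import batch_experiences

# Each sampled experience is a list of consecutive transitions (length 1 here).
experiences = [
    [dict(state=0, action=0, reward=1.0, next_state=1,
          next_action=0, is_state_terminal=False)],
    [dict(state=1, action=1, reward=0.5, next_state=2,
          next_action=1, is_state_terminal=True)],
]

# Positional arguments: experiences, device, phi (preprocessing), gamma.
batch = batch_experiences(experiences, torch.device("cpu"), lambda x: x, 0.99)

# The result is a dict of batched tensors (state, next_state, discount,
# is_state_terminal, ...) that update_q_func and friends consume.
print(sorted(batch.keys()))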
Example #2
    def update_from_episodes(self, episodes, errors_out=None):
        assert errors_out is None, "Recurrent DQN does not support PrioritizedBuffer"
        episodes = sorted(episodes, key=len, reverse=True)
        exp_batch = batch_recurrent_experiences(
            episodes,
            device=self.device,
            phi=self.phi,
            gamma=self.gamma,
            batch_states=self.batch_states,
        )

        demo_experiences = load_experiences_from_demonstrations(
            self.expert_dataset, self.replay_updater.batchsize,
            self.reward_scale)
        demo_batch = batch_experiences(
            demo_experiences,
            device=self.device,
            phi=self.phi,
            gamma=self.gamma,
            batch_states=self.batch_states,
        )

        loss = self._compute_loss(exp_batch, demo_batch, errors_out=None)
        self.loss_record.append(float(loss.detach().cpu().numpy()))
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.optim_t += 1
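Example #2 sorts the sampled episodes from longest to shortest before handing them to batch_recurrent_experiences, likely because recurrent batching utilities expect descending lengths (torch's pack_padded_sequence, for instance, requires that ordering unless enforce_sorted=False). A tiny sketch of just that sorting step, with plain lists standing in for episodes:

# Toy "episodes": lists of transition dicts with different lengths.
episodes = [
    [dict(step=i) for i in range(2)],
    [dict(step=i) for i in range(5)],
    [dict(step=i) for i in range(3)],
]

# Longest episode first, as update_from_episodes does before batching.
episodes = sorted(episodes, key=len, reverse=True)
print([len(ep) for ep in episodes])  # [5, 3, 2]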
Example #3
    def update(self, experiences, errors_out=None):
        """Update the model from experiences"""

        batch = batch_experiences(experiences, self.device, self.phi,
                                  self.gamma)
        self.update_q_func(batch)
        if self.q_func_n_updates % self.policy_update_delay == 0:
            self.update_policy(batch)
            self.sync_target_network()
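Example #3 only refreshes the policy and the target network on every policy_update_delay-th critic update, the delayed-update trick used by TD3-style agents. The gating is just a modulo check on a counter that, presumably, update_q_func increments; in isolation:

# Delayed policy updates: the policy is refreshed every policy_update_delay
# critic updates (standalone counter mirroring Example #3).
policy_update_delay = 2
q_func_n_updates = 0

for step in range(6):
    q_func_n_updates += 1  # what update_q_func would do internally
    if q_func_n_updates % policy_update_delay == 0:
        print(f"step {step}: critic + policy + target sync")
    else:
        print(f"step {step}: critic-only update")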
Example #4
 def test_batch_experiences(self):
     experiences = []
     experiences.append([
         dict(
             state=1,
             action=1,
             reward=1,
             next_state=i,
             next_action=1,
             is_state_terminal=False,
         ) for i in range(3)
     ])
     experiences.append([
         dict(
             state=1,
             action=1,
             reward=1,
             next_state=1,
             next_action=1,
             is_state_terminal=False,
         )
     ])
     four_step_transition = [
         dict(
             state=1,
             action=1,
             reward=1,
             next_state=1,
             next_action=1,
             is_state_terminal=False,
         )
     ] * 3
     four_step_transition.append(
         dict(
             state=1,
             action=1,
             reward=1,
             next_state=5,
             next_action=1,
             is_state_terminal=True,
         ))
     experiences.append(four_step_transition)
     batch = replay_buffer.batch_experiences(experiences,
                                             torch.device("cpu"),
                                             lambda x: x, 0.99)
     self.assertEqual(batch["state"][0], 1)
     self.assertSequenceEqual(
         list(batch["is_state_terminal"]),
         list(np.asarray([0.0, 0.0, 1.0], dtype=np.float32)),
     )
     self.assertSequenceEqual(
         list(batch["discount"]),
         list(np.asarray([0.99**3, 0.99**1, 0.99**4], dtype=np.float32)),
     )
     self.assertSequenceEqual(list(batch["next_state"]),
                              list(np.asarray([2, 1, 5])))
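The assertions in this test pin down how batch_experiences collapses each multi-step experience: the discount is gamma raised to the number of steps, is_state_terminal is 1 only for the experience that ends in a terminal transition, and next_state comes from the last transition of each list (2, 1 and 5 above). The first two reductions can be checked by hand:

import numpy as np

gamma = 0.99
lengths = [3, 1, 4]  # steps in the three experience lists of the test
print(np.asarray([gamma ** n for n in lengths], dtype=np.float32))
# matches the "discount" assertion: [0.99**3, 0.99**1, 0.99**4]

# Only the 4-step experience contains a terminal transition, which is
# consistent with the [0.0, 0.0, 1.0] assertion on "is_state_terminal".
terminals = [[False] * 3, [False], [False, False, False, True]]
print([float(any(t)) for t in terminals])  # [0.0, 0.0, 1.0]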
Example #5
    def update(self, experiences, errors_out=None):
        """Update the model from experiences

        Args:
            experiences (list): List of lists of dicts.
                For DQN, each dict must contain:
                  - state (object): State
                  - action (object): Action
                  - reward (float): Reward
                  - is_state_terminal (bool): True iff next state is terminal
                  - next_state (object): Next state
                  - weight (float, optional): Weight coefficient. It can be
                    used for importance sampling.
            errors_out (list or None): If set to a list, then TD-errors
                computed from the given experiences are appended to the list.

        Returns:
            None
        """
        has_weight = "weight" in experiences[0][0]
        exp_batch = batch_experiences(
            experiences,
            device=self.device,
            phi=self.phi,
            gamma=self.gamma,
            batch_states=self.batch_states,
        )
        if self.rnd_reward:
            self.rnd_module.train(exp_batch)
        # if self.ngu_reward:
        #     self.ngu_module.train(exp_batch)
        if has_weight:
            exp_batch["weights"] = torch.tensor(  # pylint: disable=not-callable
                [elem[0]["weight"] for elem in experiences],
                device=self.device,
                dtype=torch.float32,
            )
            if errors_out is None:
                errors_out = []
        loss = self._compute_loss(exp_batch, errors_out=errors_out)
        if has_weight:
            self.replay_buffer.update_errors(errors_out)

        self.loss_record.append(float(loss.detach().cpu().numpy()))

        self.optimizer.zero_grad()
        loss.backward()
        if self.max_grad_norm is not None:
            pfrl.utils.clip_l2_grad_norm_(self.model.parameters(),
                                          self.max_grad_norm)
        self.optimizer.step()
        self.optim_t += 1
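When the sampled experiences carry a "weight" entry (prioritized replay), Example #5 builds a weights tensor and passes errors_out so that replay_buffer.update_errors can refresh the priorities afterwards. The agent's _compute_loss is not shown here, but the usual pattern is a weighted loss over per-sample TD errors; a hedged sketch with hypothetical numbers:

import torch

# Hypothetical per-sample TD errors and importance-sampling weights,
# standing in for what _compute_loss and exp_batch["weights"] would hold.
td_errors = torch.tensor([0.5, 2.0, 0.1])
weights = torch.tensor([1.0, 0.25, 0.8])

# Weighting the per-sample losses corrects for the non-uniform sampling;
# errors_out would be filled with the per-sample errors so the buffer
# can update its priorities.
loss = torch.mean(weights * td_errors ** 2)
errors_out = td_errors.abs().tolist()
print(loss.item(), errors_out)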
Example #6
    def update(self, experiences, errors_out=None):
        """Update the model from experiences"""

        batch = batch_experiences(experiences, self.device, self.phi,
                                  self.gamma)

        self.critic_optimizer.zero_grad()
        self.compute_critic_loss(batch).backward()
        self.critic_optimizer.step()

        self.actor_optimizer.zero_grad()
        self.compute_actor_loss(batch).backward()
        self.actor_optimizer.step()

        self.n_updates += 1
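Example #6 keeps separate optimizers for the critic and the actor and steps them one after the other. The real compute_critic_loss and compute_actor_loss are not reproduced here; below is a self-contained torch sketch of the same alternating pattern with toy networks and placeholder targets:

import torch
import torch.nn as nn

critic = nn.Linear(4 + 2, 1)  # toy Q(s, a)
actor = nn.Linear(4, 2)       # toy deterministic policy pi(s)
critic_optimizer = torch.optim.Adam(critic.parameters(), lr=1e-3)
actor_optimizer = torch.optim.Adam(actor.parameters(), lr=1e-3)

states = torch.randn(8, 4)
actions = torch.randn(8, 2)
targets = torch.randn(8, 1)   # placeholder TD targets

# Critic step: regress Q(s, a) toward the targets.
critic_optimizer.zero_grad()
critic_loss = ((critic(torch.cat([states, actions], dim=1)) - targets) ** 2).mean()
critic_loss.backward()
critic_optimizer.step()

# Actor step: maximize the critic's value of the actor's own actions.
actor_optimizer.zero_grad()
actor_loss = -critic(torch.cat([states, actor(states)], dim=1)).mean()
actor_loss.backward()
actor_optimizer.step()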
Example #7
    def update_from_episodes(self, episodes, errors_out=None):
        raise NotImplementedError

        # Sort episodes desc by their lengths
        sorted_episodes = list(reversed(sorted(episodes, key=len)))
        max_epi_len = len(sorted_episodes[0])

        # Precompute all the input batches
        batches = []
        for i in range(max_epi_len):
            transitions = []
            for ep in sorted_episodes:
                if len(ep) <= i:
                    break
                transitions.append([ep[i]])
            batch = batch_experiences(transitions,
                                      device=self.device,
                                      phi=self.phi,
                                      gamma=self.gamma)
            batches.append(batch)

        with self.model.state_reset(), self.target_model.state_reset():

            # Since the target model is evaluated one-step ahead,
            # its internal states need to be updated
            self.target_q_function.update_state(batches[0]["state"],
                                                batches[0]["action"])
            self.target_policy(batches[0]["state"])

            # Update critic through time
            critic_loss = 0
            for batch in batches:
                critic_loss += self.compute_critic_loss(batch)
            self.critic_optimizer.update(lambda: critic_loss / max_epi_len)

        with self.model.state_reset():

            # Update actor through time
            actor_loss = 0
            for batch in batches:
                actor_loss += self.compute_actor_loss(batch)
            self.actor_optimizer.update(lambda: actor_loss / max_epi_len)
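This method raises NotImplementedError, so everything after the raise is dead code (it still contains chainer-style optimizer.update calls). The batching idea it sketches is still worth illustrating: at timestep i, only episodes longer than i contribute a transition, and the early break is safe because the episodes are sorted longest-first. With toy data:

# Toy episodes already sorted from longest to shortest.
episodes = [["a0", "a1", "a2"], ["b0", "b1"], ["c0"]]
max_epi_len = len(episodes[0])

for i in range(max_epi_len):
    transitions = []
    for ep in episodes:
        if len(ep) <= i:
            break  # remaining episodes are shorter, so stop early
        transitions.append(ep[i])
    print(i, transitions)
# 0 ['a0', 'b0', 'c0']
# 1 ['a1', 'b1']
# 2 ['a2']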
Example #8
    def update(self, experiences, errors_out=None):
        """Update the model from experiences

        Args:
            experiences (list): List of lists of dicts.
                For DQN, each dict must contain:
                  - state (object): State
                  - action (object): Action
                  - reward (float): Reward
                  - is_state_terminal (bool): True iff next state is terminal
                  - next_state (object): Next state
                  - weight (float, optional): Weight coefficient. It can be
                    used for importance sampling.
            errors_out (list or None): If set to a list, then TD-errors
                computed from the given experiences are appended to the list.

        Returns:
            None

        Changes from DQN:
            Also learns from expert demonstrations.
        """
        has_weight = "weight" in experiences[0][0]
        exp_batch = batch_experiences(
            experiences,
            device=self.device,
            phi=self.phi,
            gamma=self.gamma,
            batch_states=self.batch_states,
        )
        if has_weight:
            exp_batch["weights"] = torch.tensor(
                [elem[0]["weight"] for elem in experiences],
                device=self.device,
                dtype=torch.float32,
            )
            if errors_out is None:
                errors_out = []

        if self.reward_based_sampler is not None:
            demo_experiences = self.reward_based_sampler.sample(experiences)
        else:
            demo_experiences = load_experiences_from_demonstrations(
                self.expert_dataset, self.replay_updater.batchsize,
                self.reward_scale)
        demo_batch = batch_experiences(
            demo_experiences,
            device=self.device,
            phi=self.phi,
            gamma=self.gamma,
            batch_states=self.batch_states,
        )

        loss = self._compute_loss(exp_batch, demo_batch, errors_out=errors_out)
        if has_weight:
            self.replay_buffer.update_errors(errors_out)

        self.loss_record.append(float(loss.detach().cpu().numpy()))

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.optim_t += 1
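Example #8 differs from the plain DQN update in that every step also draws a batch of expert demonstrations (via reward_based_sampler or load_experiences_from_demonstrations) and feeds both batches to _compute_loss. That loss is not shown here; purely as a hypothetical illustration, demonstration-aided agents such as DQfD typically add a weighted demonstration term to the ordinary TD loss:

import torch

# Hypothetical per-sample TD errors for the replay batch and the demo batch,
# plus a hypothetical supervised (large-margin) term and weighting coefficient.
replay_td_errors = torch.tensor([0.3, 1.2, 0.7])
demo_td_errors = torch.tensor([0.1, 0.4, 0.9])
demo_margin_loss = torch.tensor(0.2)
demo_loss_coeff = 1.0

loss = (replay_td_errors ** 2).mean() + demo_loss_coeff * (
    (demo_td_errors ** 2).mean() + demo_margin_loss
)
print(float(loss))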