def update(self, experiences, errors_out=None): """Update the model from experiences""" batch = batch_experiences(experiences, self.device, self.phi, self.gamma) self.update_q_func(batch) self.update_policy_and_temperature(batch) self.sync_target_network()
def update_from_episodes(self, episodes, errors_out=None):
    """Update the model from whole episodes, mixed with demonstration data."""
    assert errors_out is None, "Recurrent DQN does not support PrioritizedBuffer"
    # Sort episodes by length in descending order
    episodes = sorted(episodes, key=len, reverse=True)
    exp_batch = batch_recurrent_experiences(
        episodes,
        device=self.device,
        phi=self.phi,
        gamma=self.gamma,
        batch_states=self.batch_states,
    )
    demo_experiences = load_experiences_from_demonstrations(
        self.expert_dataset, self.replay_updater.batchsize, self.reward_scale
    )
    demo_batch = batch_experiences(
        demo_experiences,
        device=self.device,
        phi=self.phi,
        gamma=self.gamma,
        batch_states=self.batch_states,
    )
    loss = self._compute_loss(exp_batch, demo_batch, errors_out=None)
    self.loss_record.append(float(loss.detach().cpu().numpy()))
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()
    self.optim_t += 1
def update(self, experiences, errors_out=None): """Update the model from experiences""" batch = batch_experiences(experiences, self.device, self.phi, self.gamma) self.update_q_func(batch) if self.q_func_n_updates % self.policy_update_delay == 0: self.update_policy(batch) self.sync_target_network()
def test_batch_experiences(self):
    experiences = []
    # Three-step transition whose last next_state is 2
    experiences.append(
        [
            dict(
                state=1,
                action=1,
                reward=1,
                next_state=i,
                next_action=1,
                is_state_terminal=False,
            )
            for i in range(3)
        ]
    )
    # Single-step transition
    experiences.append(
        [
            dict(
                state=1,
                action=1,
                reward=1,
                next_state=1,
                next_action=1,
                is_state_terminal=False,
            )
        ]
    )
    # Four-step transition that ends in a terminal state with next_state 5
    four_step_transition = [
        dict(
            state=1,
            action=1,
            reward=1,
            next_state=1,
            next_action=1,
            is_state_terminal=False,
        )
    ] * 3
    four_step_transition.append(
        dict(
            state=1,
            action=1,
            reward=1,
            next_state=5,
            next_action=1,
            is_state_terminal=True,
        )
    )
    experiences.append(four_step_transition)
    batch = replay_buffer.batch_experiences(
        experiences, torch.device("cpu"), lambda x: x, 0.99
    )
    self.assertEqual(batch["state"][0], 1)
    # Only the last transition of each sequence is terminal
    self.assertSequenceEqual(
        list(batch["is_state_terminal"]),
        list(np.asarray([0.0, 0.0, 1.0], dtype=np.float32)),
    )
    # The discount of an n-step transition is gamma ** n
    self.assertSequenceEqual(
        list(batch["discount"]),
        list(np.asarray([0.99 ** 3, 0.99 ** 1, 0.99 ** 4], dtype=np.float32)),
    )
    # next_state is taken from the last step of each transition
    self.assertSequenceEqual(
        list(batch["next_state"]), list(np.asarray([2, 1, 5]))
    )
def update(self, experiences, errors_out=None): """Update the model from experiences Args: experiences (list): List of lists of dicts. For DQN, each dict must contains: - state (object): State - action (object): Action - reward (float): Reward - is_state_terminal (bool): True iff next state is terminal - next_state (object): Next state - weight (float, optional): Weight coefficient. It can be used for importance sampling. errors_out (list or None): If set to a list, then TD-errors computed from the given experiences are appended to the list. Returns: None """ has_weight = "weight" in experiences[0][0] exp_batch = batch_experiences( experiences, device=self.device, phi=self.phi, gamma=self.gamma, batch_states=self.batch_states, ) if self.rnd_reward: self.rnd_module.train(exp_batch) # if self.ngu_reward: # self.ngu_module.train(exp_batch) if has_weight: exp_batch["weights"] = torch.tensor( # pylint: disable=not-callable [elem[0]["weight"] for elem in experiences], device=self.device, dtype=torch.float32, ) if errors_out is None: errors_out = [] loss = self._compute_loss(exp_batch, errors_out=errors_out) if has_weight: self.replay_buffer.update_errors(errors_out) self.loss_record.append(float(loss.detach().cpu().numpy())) self.optimizer.zero_grad() loss.backward() if self.max_grad_norm is not None: pfrl.utils.clip_l2_grad_norm_(self.model.parameters(), self.max_grad_norm) self.optimizer.step() self.optim_t += 1
def update(self, experiences, errors_out=None): """Update the model from experiences""" batch = batch_experiences(experiences, self.device, self.phi, self.gamma) self.critic_optimizer.zero_grad() self.compute_critic_loss(batch).backward() self.critic_optimizer.step() self.actor_optimizer.zero_grad() self.compute_actor_loss(batch).backward() self.actor_optimizer.step() self.n_updates += 1
def update_from_episodes(self, episodes, errors_out=None):
    raise NotImplementedError
    # Sort episodes desc by their lengths
    sorted_episodes = list(reversed(sorted(episodes, key=len)))
    max_epi_len = len(sorted_episodes[0])
    # Precompute all the input batches
    batches = []
    for i in range(max_epi_len):
        transitions = []
        for ep in sorted_episodes:
            if len(ep) <= i:
                break
            transitions.append([ep[i]])
        batch = batch_experiences(
            transitions, xp=self.device, phi=self.phi, gamma=self.gamma
        )
        batches.append(batch)
    with self.model.state_reset(), self.target_model.state_reset():
        # Since the target model is evaluated one-step ahead,
        # its internal states need to be updated
        self.target_q_function.update_state(
            batches[0]["state"], batches[0]["action"]
        )
        self.target_policy(batches[0]["state"])
        # Update critic through time
        critic_loss = 0
        for batch in batches:
            critic_loss += self.compute_critic_loss(batch)
        self.critic_optimizer.update(lambda: critic_loss / max_epi_len)
    with self.model.state_reset():
        # Update actor through time
        actor_loss = 0
        for batch in batches:
            actor_loss += self.compute_actor_loss(batch)
        self.actor_optimizer.update(lambda: actor_loss / max_epi_len)
def update(self, experiences, errors_out=None): """Update the model from experiences Args: experiences (list): List of lists of dicts. For DQN, each dict must contains: - state (object): State - action (object): Action - reward (float): Reward - is_state_terminal (bool): True iff next state is terminal - next_state (object): Next state - weight (float, optional): Weight coefficient. It can be used for importance sampling. errors_out (list or None): If set to a list, then TD-errors computed from the given experiences are appended to the list. Returns: None Changes from DQN: Learned from demonstrations """ has_weight = "weight" in experiences[0][0] exp_batch = batch_experiences( experiences, device=self.device, phi=self.phi, gamma=self.gamma, batch_states=self.batch_states, ) if has_weight: exp_batch["weights"] = torch.tensor( [elem[0]["weight"] for elem in experiences], device=self.device, dtype=torch.float32, ) if errors_out is None: errors_out = [] if self.reward_based_sampler is not None: demo_experiences = self.reward_based_sampler.sample(experiences) else: demo_experiences = load_experiences_from_demonstrations( self.expert_dataset, self.replay_updater.batchsize, self.reward_scale) demo_batch = batch_experiences( demo_experiences, device=self.device, phi=self.phi, gamma=self.gamma, batch_states=self.batch_states, ) loss = self._compute_loss(exp_batch, demo_batch, errors_out=errors_out) if has_weight: self.replay_buffer.update_errors(errors_out) self.loss_record.append(float(loss.detach().cpu().numpy())) self.optimizer.zero_grad() loss.backward() self.optimizer.step() self.optim_t += 1