def _update_vf(self, dataset):
    """Update the value function using a given dataset.

    The value function is updated via SGD to minimize TD(lambda) errors.
    """
    assert "state" in dataset[0]
    assert "v_teacher" in dataset[0]

    for batch in _yield_minibatches(
        dataset, minibatch_size=self.vf_batch_size, num_epochs=self.vf_epochs
    ):
        states = batch_states([b["state"] for b in batch], self.device, self.phi)
        if self.obs_normalizer:
            states = self.obs_normalizer(states, update=False)
        vs_teacher = torch.as_tensor(
            [b["v_teacher"] for b in batch],
            device=self.device,
            dtype=torch.float,
        )
        vs_pred = self.vf(states)
        vf_loss = F.mse_loss(vs_pred, vs_teacher[..., None])
        self.vf.zero_grad()
        vf_loss.backward()
        if self.max_grad_norm is not None:
            clip_l2_grad_norm_(self.vf.parameters(), self.max_grad_norm)
        self.vf_optimizer.step()
def update_policy_with_goal(self, batch):
    """Compute loss for actor."""
    batch_state = batch["state"]
    batch_goal = batch["goal"]

    action_distrib = self.policy(torch.cat([batch_state, batch_goal], -1))
    onpolicy_actions = action_distrib.rsample()

    entropy_term = 0
    if self.add_entropy:
        log_prob = action_distrib.log_prob(onpolicy_actions)
        entropy_term = self.temperature * log_prob[..., None]

    q = self.q_func1((torch.cat([batch_state, batch_goal], -1), onpolicy_actions))

    # Since we want to maximize Q, loss is negation of Q
    loss = -torch.mean(-entropy_term + q)
    self.policy_loss_record.append(float(loss))

    self.policy_optimizer.zero_grad()
    loss.backward()

    # get policy gradients
    gradients = self.get_and_flatten_policy_gradients()
    gradient_variance = torch.var(gradients)
    gradient_mean = torch.mean(gradients)
    self.policy_gradients_variance_record.append(float(gradient_variance))
    self.policy_gradients_mean_record.append(float(gradient_mean))

    if self.max_grad_norm is not None:
        clip_l2_grad_norm_(self.policy.parameters(), self.max_grad_norm)
    self.policy_optimizer.step()
    self.policy_n_updates += 1
def _update_vf_once_recurrent(self, episodes):
    # Sort episodes desc by length for pack_sequence
    episodes = sorted(episodes, key=len, reverse=True)
    flat_transitions = flatten_sequences_time_first(episodes)

    # Prepare data for a recurrent model
    seqs_states = []
    for ep in episodes:
        states = self.batch_states(
            [transition["state"] for transition in ep], self.device, self.phi)
        if self.obs_normalizer:
            states = self.obs_normalizer(states, update=False)
        seqs_states.append(states)

    flat_vs_teacher = torch.as_tensor(
        [[transition["v_teacher"]] for transition in flat_transitions],
        device=self.device,
        dtype=torch.float,
    )

    with torch.no_grad():
        vf_rs = concatenate_recurrent_states(
            _collect_first_recurrent_states_of_vf(episodes))

    flat_vs_pred, _ = pack_and_forward(self.vf, seqs_states, vf_rs)
    vf_loss = F.mse_loss(flat_vs_pred, flat_vs_teacher)

    self.vf.zero_grad()
    vf_loss.backward()
    if self.max_grad_norm is not None:
        clip_l2_grad_norm_(self.vf.parameters(), self.max_grad_norm)
    self.vf_optimizer.step()
def update_temperature(self, log_prob):
    assert not log_prob.requires_grad
    # Temperature loss: adjust the temperature so that the policy's entropy
    # moves toward the target entropy.
    loss = -torch.mean(self.temperature_holder() * (log_prob + self.entropy_target))
    self.temperature_optimizer.zero_grad()
    loss.backward()
    if self.max_grad_norm is not None:
        clip_l2_grad_norm_(self.temperature_holder.parameters(), self.max_grad_norm)
    self.temperature_optimizer.step()
def batch_update(self):
    assert len(self.reward_sequences) == self.batchsize
    assert len(self.log_prob_sequences) == self.batchsize
    assert len(self.entropy_sequences) == self.batchsize
    # Update the model
    assert self.n_backward == 0
    self.accumulate_grad()
    if self.max_grad_norm is not None:
        clip_l2_grad_norm_(self.model.parameters(), self.max_grad_norm)
    self.optimizer.step()
    self.n_backward = 0
def update_q_func(self, batch):
    """Compute loss for a given Q-function."""
    batch_next_state = batch["next_state"].float()
    batch_rewards = batch["reward"].float()
    batch_terminal = batch["is_state_terminal"].float()
    batch_state = batch["state"].float()
    batch_actions = batch["action"].float()
    batch_discount = batch["discount"].float()

    with torch.no_grad(), pfrl.utils.evaluating(
            self.policy), pfrl.utils.evaluating(
            self.target_q_func1), pfrl.utils.evaluating(
            self.target_q_func2):
        next_action_distrib = self.policy(batch_next_state)
        next_actions_normalized = next_action_distrib.sample()
        next_actions = self.scale * next_actions_normalized
        next_log_prob = next_action_distrib.log_prob(next_actions_normalized)
        next_q1 = self.target_q_func1((batch_next_state, next_actions))
        next_q2 = self.target_q_func2((batch_next_state, next_actions))
        next_q = torch.min(next_q1, next_q2)
        entropy_term = self.temperature * next_log_prob[..., None]
        assert next_q.shape == entropy_term.shape

        target_q = batch_rewards + batch_discount * (
            1.0 - batch_terminal) * torch.flatten(next_q - entropy_term)

    predict_q1 = torch.flatten(self.q_func1((batch_state, batch_actions)))
    predict_q2 = torch.flatten(self.q_func2((batch_state, batch_actions)))

    loss1 = 0.5 * F.mse_loss(target_q, predict_q1)
    loss2 = 0.5 * F.mse_loss(target_q, predict_q2)

    # Update stats
    self.q1_record.extend(predict_q1.detach().cpu().numpy())
    self.q2_record.extend(predict_q2.detach().cpu().numpy())
    self.q_func1_loss_record.append(float(loss1))
    self.q_func2_loss_record.append(float(loss2))

    self.q_func1_optimizer.zero_grad()
    loss1.backward()
    if self.max_grad_norm is not None:
        clip_l2_grad_norm_(self.q_func1.parameters(), self.max_grad_norm)
    self.q_func1_optimizer.step()

    self.q_func2_optimizer.zero_grad()
    loss2.backward()
    if self.max_grad_norm is not None:
        clip_l2_grad_norm_(self.q_func2.parameters(), self.max_grad_norm)
    self.q_func2_optimizer.step()
def update_policy_with_goal(self, batch):
    """Compute loss for actor."""
    batch_state = batch["state"]
    batch_goal = batch["goal"]

    action_distrib = self.policy(torch.cat([batch_state, batch_goal], -1))
    onpolicy_actions_normalized = action_distrib.rsample()
    onpolicy_actions = self.scale * onpolicy_actions_normalized

    entropy_term = 0
    if self.add_entropy:
        log_prob = action_distrib.log_prob(onpolicy_actions_normalized)
        entropy_term = self.temperature * log_prob[..., None]
        if self.entropy_target is not None:
            self.update_temperature(log_prob.detach())
        self.entropy_record.append(float(torch.mean(-entropy_term)))
        self.temperature_record.append(self.temperature)

    q = self.q_func1((torch.cat([batch_state, batch_goal], -1), onpolicy_actions))

    # Since we want to maximize Q, loss is negation of Q
    loss = -torch.mean(-entropy_term + q)
    self.policy_loss_record.append(float(loss))

    self.policy_optimizer.zero_grad()
    loss.backward()

    # get policy gradients
    # gradients = self.get_and_flatten_policy_gradients()
    # gradient_variance = torch.var(gradients)
    # gradient_mean = torch.mean(gradients)
    gradient_variance = 0
    gradient_mean = 0
    self.policy_gradients_variance_record.append(float(gradient_variance))
    self.policy_gradients_mean_record.append(float(gradient_mean))

    if self.max_grad_norm is not None:
        clip_l2_grad_norm_(self.policy.parameters(), self.max_grad_norm)
    self.policy_optimizer.step()
    self.policy_n_updates += 1

    self.kl_divergence = self.compute_kl(
        self.policy, self.target_policy, batch_state, batch_goal)
    self.one_step_kl_divergence = self.compute_kl(
        self.policy, self.prior_policy, batch_state, batch_goal)
    self.prior_policy = copy.deepcopy(self.policy)
def _test_clip_l2_grad_norm_(gpu):
    if gpu >= 0:
        device = torch.device("cuda:{}".format(gpu))
    else:
        device = torch.device("cpu")

    model = nn.Sequential(
        nn.Linear(2, 10),
        nn.ReLU(),
        nn.Linear(10, 3),
    ).to(device)
    x = torch.rand(7, 2).to(device)

    def backward():
        model.zero_grad()
        loss = model(x).mean()
        loss.backward()

    backward()
    raw_grads = _get_grad_vector(model)

    # Threshold large enough not to affect grads
    th = 10000
    backward()
    nn.utils.clip_grad_norm_(model.parameters(), th)
    clipped_grads = _get_grad_vector(model)
    backward()
    clip_l2_grad_norm_(model.parameters(), th)
    our_clipped_grads = _get_grad_vector(model)
    np.testing.assert_allclose(raw_grads, clipped_grads)
    np.testing.assert_allclose(raw_grads, our_clipped_grads)

    # Threshold small enough to affect grads
    th = 1e-2
    backward()
    nn.utils.clip_grad_norm_(model.parameters(), th)
    clipped_grads = _get_grad_vector(model)
    backward()
    clip_l2_grad_norm_(model.parameters(), th)
    our_clipped_grads = _get_grad_vector(model)
    with pytest.raises(AssertionError):
        np.testing.assert_allclose(raw_grads, clipped_grads, rtol=1e-5)
    with pytest.raises(AssertionError):
        np.testing.assert_allclose(raw_grads, our_clipped_grads, rtol=1e-5)
    np.testing.assert_allclose(clipped_grads, our_clipped_grads, rtol=1e-5)
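# For reference, a minimal sketch of an in-place L2 gradient clipper with the
# behaviour the test above checks (i.e. matching torch.nn.utils.clip_grad_norm_).
# The name _clip_l2_grad_norm_sketch is illustrative only and is not the
# library's actual implementation of clip_l2_grad_norm_.
def _clip_l2_grad_norm_sketch(parameters, max_norm):
    grads = [p.grad for p in parameters if p.grad is not None]
    if not grads:
        return torch.zeros(())
    # Global L2 norm over all parameter gradients.
    total_norm = torch.norm(torch.stack([g.detach().norm(2) for g in grads]), 2)
    # Scale gradients in place only when the norm exceeds the threshold.
    clip_coef = float(max_norm) / (float(total_norm) + 1e-6)
    if clip_coef < 1.0:
        for g in grads:
            g.detach().mul_(clip_coef)
    return total_norm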
def update_q_func(self, batch):
    """Compute loss for a given Q-function."""
    batch_next_state = batch["next_state"]
    batch_rewards = batch["reward"]
    batch_terminal = batch["is_state_terminal"]
    batch_state = batch["state"]
    batch_actions = batch["action"]
    batch_discount = batch["discount"]

    with torch.no_grad(), pfrl.utils.evaluating(
            self.target_policy), pfrl.utils.evaluating(
            self.target_q_func1), pfrl.utils.evaluating(
            self.target_q_func2):
        next_actions = self.target_policy_smoothing_func(
            self.target_policy(batch_next_state).sample())
        next_q1 = self.target_q_func1((batch_next_state, next_actions))
        next_q2 = self.target_q_func2((batch_next_state, next_actions))
        next_q = torch.min(next_q1, next_q2)
        target_q = batch_rewards + batch_discount * (
            1.0 - batch_terminal) * torch.flatten(next_q)

    predict_q1 = torch.flatten(self.q_func1((batch_state, batch_actions)))
    predict_q2 = torch.flatten(self.q_func2((batch_state, batch_actions)))

    loss1 = F.mse_loss(target_q, predict_q1)
    loss2 = F.mse_loss(target_q, predict_q2)

    # Update stats
    self.q1_record.extend(predict_q1.detach().cpu().numpy())
    self.q2_record.extend(predict_q2.detach().cpu().numpy())
    self.q_func1_loss_record.append(float(loss1))
    self.q_func2_loss_record.append(float(loss2))

    self.q_func1_optimizer.zero_grad()
    loss1.backward()
    if self.max_grad_norm is not None:
        clip_l2_grad_norm_(self.q_func1.parameters(), self.max_grad_norm)
    self.q_func1_optimizer.step()

    self.q_func2_optimizer.zero_grad()
    loss2.backward()
    if self.max_grad_norm is not None:
        clip_l2_grad_norm_(self.q_func2.parameters(), self.max_grad_norm)
    self.q_func2_optimizer.step()

    self.q_func_n_updates += 1
def update(self):
    with torch.no_grad():
        _, next_value = self.model(self.states[-1])
        next_value = next_value[:, 0]

    self._compute_returns(next_value)
    pout, values = self.model(self.states[:-1].reshape(-1, *self.obs_shape))

    actions = self.actions.reshape(-1, *self.action_shape)
    dist_entropy = pout.entropy().mean()
    action_log_probs = pout.log_prob(actions)

    values = values.reshape((self.update_steps, self.num_processes))
    action_log_probs = action_log_probs.reshape(
        (self.update_steps, self.num_processes)
    )
    advantages = self.returns[:-1] - values
    value_loss = (advantages * advantages).mean()
    action_loss = -(advantages.detach() * action_log_probs).mean()

    self.optimizer.zero_grad()
    (
        value_loss * self.v_loss_coef
        + action_loss * self.pi_loss_coef
        - dist_entropy * self.entropy_coeff
    ).backward()
    if self.max_grad_norm is not None:
        clip_l2_grad_norm_(self.model.parameters(), self.max_grad_norm)
    self.optimizer.step()

    self.states[0] = self.states[-1]
    self.t_start = self.t

    # Update stats
    self.average_actor_loss += (1 - self.average_actor_loss_decay) * (
        float(action_loss) - self.average_actor_loss
    )
    self.average_value += (1 - self.average_value_decay) * (
        float(value_loss) - self.average_value
    )
    self.average_entropy += (1 - self.average_entropy_decay) * (
        float(dist_entropy) - self.average_entropy
    )
def update_policy(self, batch):
    """Compute loss for actor."""
    batch_state = batch["state"]

    onpolicy_actions = self.policy(batch_state).rsample()
    q = self.q_func1((batch_state, onpolicy_actions))

    # Since we want to maximize Q, loss is negation of Q
    loss = -torch.mean(q)

    self.policy_loss_record.append(float(loss))
    self.policy_optimizer.zero_grad()
    loss.backward()
    if self.max_grad_norm is not None:
        clip_l2_grad_norm_(self.policy.parameters(), self.max_grad_norm)
    self.policy_optimizer.step()
    self.policy_n_updates += 1
def update(self):
    self._compute_returns()
    pout, values = self.model(
        self.states[:self.update_steps].reshape(-1, *self.obs_shape))

    actions = self.actions[:self.update_steps].reshape(-1, *self.action_shape)
    dist_entropy = pout.entropy().mean()
    action_log_probs = pout.log_prob(actions)

    values = values.reshape((self.update_steps, self.num_processes))
    action_log_probs = action_log_probs.reshape(
        (self.update_steps, self.num_processes))
    advantages = self.returns - values
    value_loss = (advantages * advantages).mean()
    action_loss = -(advantages.detach() * action_log_probs).mean()

    self.optimizer.zero_grad()
    (value_loss * self.v_loss_coef
     + action_loss * self.pi_loss_coef
     - dist_entropy * self.entropy_coeff).backward()
    if self.max_grad_norm is not None:
        clip_l2_grad_norm_(self.model.parameters(), self.max_grad_norm)
    self.optimizer.step()

    # NOTE: Update time-step
    self.t_start += self.update_steps

    # sliding window
    self.states[:self.update_steps] = self.states[self.update_steps:-1]
    self.actions[:self.update_steps] = self.actions[self.update_steps:]
    self.rewards[:self.update_steps] = self.rewards[self.update_steps:]
    self.value_preds[:self.update_steps] = self.value_preds[self.update_steps:-1]

    # Update stats
    self.average_actor_loss += (1 - self.average_actor_loss_decay) * (
        float(action_loss) - self.average_actor_loss)
    self.average_value += (1 - self.average_value_decay) * (
        float(value_loss) - self.average_value)
    self.average_entropy += (1 - self.average_entropy_decay) * (
        float(dist_entropy) - self.average_entropy)
def update(
    self,
    t_start,
    t_stop,
    R,
    actions,
    rewards,
    values,
    action_values,
    action_distribs,
    action_distribs_mu,
    avg_action_distribs,
):
    assert np.isscalar(R)
    self.assert_shared_memory()

    total_loss = self.compute_loss(
        t_start=t_start,
        t_stop=t_stop,
        R=R,
        actions=actions,
        rewards=rewards,
        values=values,
        action_values=action_values,
        action_distribs=action_distribs,
        action_distribs_mu=action_distribs_mu,
        avg_action_distribs=avg_action_distribs,
    )

    # Compute gradients using thread-specific model
    self.model.zero_grad()
    total_loss.squeeze().backward()
    if self.max_grad_norm is not None:
        clip_l2_grad_norm_(self.model.parameters(), self.max_grad_norm)
    # Copy the gradients to the globally shared model
    copy_param.copy_grad(target_link=self.shared_model, source_link=self.model)
    self.optimizer.step()
    self.sync_parameters()
def update_policy_and_temperature(self, batch):
    """Compute loss for actor."""
    batch_state = batch["state"].float()

    action_distrib = self.policy(batch_state)
    actions_normalized = action_distrib.rsample()
    actions = self.scale * actions_normalized
    log_prob = action_distrib.log_prob(actions_normalized)

    q1 = self.q_func1((batch_state, actions))
    q2 = self.q_func2((batch_state, actions))
    q = torch.min(q1, q2)

    entropy_term = self.temperature * log_prob[..., None]
    assert q.shape == entropy_term.shape
    loss = torch.mean(entropy_term - q)

    self.policy_optimizer.zero_grad()
    loss.backward()
    if self.max_grad_norm is not None:
        clip_l2_grad_norm_(self.policy.parameters(), self.max_grad_norm)
    self.policy_optimizer.step()
    self.n_policy_updates += 1

    if self.entropy_target is not None:
        self.update_temperature(log_prob.detach())

    # Record entropy
    with torch.no_grad():
        try:
            self.entropy_record.extend(
                action_distrib.entropy().detach().cpu().numpy())
        except NotImplementedError:
            # Record -log p(x) instead
            self.entropy_record.extend(-log_prob.detach().cpu().numpy())
def update_q_func_with_goal(self, batch):
    """Compute loss for a given Q-function (critic)."""
    batch_next_state = batch["next_state"]
    batch_next_goal = batch["next_goal"]
    batch_rewards = batch["reward"]
    batch_terminal = batch["is_state_terminal"]
    batch_state = batch["state"]
    batch_goal = batch["goal"]
    batch_actions = batch["action"]
    batch_discount = batch["discount"]

    with torch.no_grad(), pfrl.utils.evaluating(
            self.target_policy), pfrl.utils.evaluating(
            self.target_q_func1), pfrl.utils.evaluating(
            self.target_q_func2):
        next_action_distrib = self.target_policy(
            torch.cat([batch_next_state, batch_next_goal], -1))
        next_actions = self.target_policy_smoothing_func(
            next_action_distrib.sample())

        entropy_term = 0
        if self.add_entropy:
            next_log_prob = next_action_distrib.log_prob(next_actions)
            entropy_term = self.temperature * next_log_prob[..., None]

        next_q1 = self.target_q_func1(
            (torch.cat([batch_next_state, batch_next_goal], -1), next_actions))
        next_q2 = self.target_q_func2(
            (torch.cat([batch_next_state, batch_next_goal], -1), next_actions))
        next_q = torch.min(next_q1, next_q2)

        target_q = batch_rewards + batch_discount * (
            1.0 - batch_terminal) * torch.flatten(next_q - entropy_term)

    predict_q1 = torch.flatten(
        self.q_func1((torch.cat([batch_state, batch_goal], -1), batch_actions)))
    predict_q2 = torch.flatten(
        self.q_func2((torch.cat([batch_state, batch_goal], -1), batch_actions)))

    loss1 = F.smooth_l1_loss(target_q, predict_q1)
    loss2 = F.smooth_l1_loss(target_q, predict_q2)

    # Update stats
    self.q1_record.extend(predict_q1.detach().cpu().numpy())
    self.q2_record.extend(predict_q2.detach().cpu().numpy())
    self.q_func1_loss_record.append(float(loss1))
    self.q_func2_loss_record.append(float(loss2))
    q1_recent_variance = np.var(
        list(self.q1_record)[-self.recent_variance_size:])
    q2_recent_variance = np.var(
        list(self.q2_record)[-self.recent_variance_size:])
    self.q_func1_variance_record.append(q1_recent_variance)
    self.q_func2_variance_record.append(q2_recent_variance)

    self.q_func1_optimizer.zero_grad()
    loss1.backward()
    if self.max_grad_norm is not None:
        clip_l2_grad_norm_(self.q_func1.parameters(), self.max_grad_norm)
    self.q_func1_optimizer.step()

    self.q_func2_optimizer.zero_grad()
    loss2.backward()
    if self.max_grad_norm is not None:
        clip_l2_grad_norm_(self.q_func2.parameters(), self.max_grad_norm)
    self.q_func2_optimizer.step()

    self.q_func_n_updates += 1
def update_with_accumulated_grad(self):
    assert self.n_backward == self.batchsize
    if self.max_grad_norm is not None:
        clip_l2_grad_norm_(self.model.parameters(), self.max_grad_norm)
    self.optimizer.step()
    self.n_backward = 0
def update(self, statevar):
    assert self.t_start < self.t
    n = self.t - self.t_start

    self.assert_shared_memory()

    if statevar is None:
        R = 0
    else:
        with torch.no_grad(), pfrl.utils.evaluating(self.model):
            if self.recurrent:
                (_, vout), _ = one_step_forward(
                    self.model, statevar, self.train_recurrent_states)
            else:
                _, vout = self.model(statevar)
        R = float(vout)

    pi_loss_factor = self.pi_loss_coef
    v_loss_factor = self.v_loss_coef

    # Normalize the loss of sequences truncated by terminal states
    if self.keep_loss_scale_same and self.t - self.t_start < self.t_max:
        factor = self.t_max / (self.t - self.t_start)
        pi_loss_factor *= factor
        v_loss_factor *= factor

    if self.normalize_grad_by_t_max:
        pi_loss_factor /= self.t - self.t_start
        v_loss_factor /= self.t - self.t_start

    # Batch re-compute for efficient backprop
    batch_obs = self.batch_states(
        [self.past_obs[i] for i in range(self.t_start, self.t)],
        self.device,
        self.phi,
    )
    if self.recurrent:
        (batch_distrib, batch_v), _ = pack_and_forward(
            self.model,
            [batch_obs],
            self.past_recurrent_state[self.t_start],
        )
    else:
        batch_distrib, batch_v = self.model(batch_obs)
    batch_action = torch.stack(
        [self.past_action[i] for i in range(self.t_start, self.t)])
    batch_log_prob = batch_distrib.log_prob(batch_action)
    batch_entropy = batch_distrib.entropy()

    rev_returns = []
    for i in reversed(range(self.t_start, self.t)):
        R *= self.gamma
        R += self.past_rewards[i]
        rev_returns.append(R)
    batch_return = torch.as_tensor(list(reversed(rev_returns)), dtype=torch.float)
    batch_adv = batch_return - batch_v.detach().squeeze(-1)

    assert batch_log_prob.shape == (n,)
    assert batch_adv.shape == (n,)
    assert batch_entropy.shape == (n,)
    pi_loss = torch.sum(
        -batch_adv * batch_log_prob - self.beta * batch_entropy, dim=0)

    assert batch_v.shape == (n, 1)
    assert batch_return.shape == (n,)
    v_loss = F.mse_loss(batch_v, batch_return[..., None], reduction="sum") / 2

    if pi_loss_factor != 1.0:
        pi_loss *= pi_loss_factor
    if v_loss_factor != 1.0:
        v_loss *= v_loss_factor

    if self.process_idx == 0:
        logger.debug("pi_loss:%s v_loss:%s", pi_loss, v_loss)

    total_loss = torch.squeeze(pi_loss) + torch.squeeze(v_loss)

    # Compute gradients using thread-specific model
    self.model.zero_grad()
    total_loss.backward()
    if self.max_grad_norm is not None:
        clip_l2_grad_norm_(self.model.parameters(), self.max_grad_norm)
    # Copy the gradients to the globally shared model
    copy_param.copy_grad(target_link=self.shared_model, source_link=self.model)
    # Update the globally shared model
    self.optimizer.step()
    if self.process_idx == 0:
        logger.debug("update")

    self.sync_parameters()

    self.past_obs = {}
    self.past_action = {}
    self.past_rewards = {}
    self.past_recurrent_state = {}

    self.t_start = self.t
def our_clip():
    # Helper closure; `model` and `th` are expected to be defined in the
    # enclosing scope.
    clip_l2_grad_norm_(model.parameters(), th)