def _backward(self, batch):
    """Loss is encoded in here. Defining a new loss function
    would start by rewriting this function"""
    states, acs, advs, rs, _ = convert_batch(batch)
    values, ac_logprobs, entropy = self._evaluate(states, acs)
    # Policy-gradient term: advantage-weighted negative log-probabilities.
    pi_err = -(advs * ac_logprobs).sum()
    # Value-function term: squared error against the return targets.
    value_err = 0.5 * (values - rs).pow(2).sum()
    self.optimizer.zero_grad()
    # Combined loss with a hard-coded value weight (0.5) and entropy bonus (0.01).
    overall_err = 0.5 * value_err + pi_err - entropy * 0.01
    overall_err.backward()
    torch.nn.utils.clip_grad_norm(self._model.parameters(), 40)
def _backward(self, batch):
    """Loss is encoded in here. Defining a new loss function
    would start by rewriting this function"""
    states, acs, advs, rs, _ = convert_batch(batch)
    values, ac_logprobs, entropy = self._evaluate(states, acs)
    pi_err = -(advs * ac_logprobs).sum()
    value_err = 0.5 * (values - rs).pow(2).sum()
    self.optimizer.zero_grad()
    overall_err = (pi_err +
                   value_err * self.config["vf_loss_coeff"] +
                   entropy * self.config["entropy_coeff"])
    overall_err.backward()
    torch.nn.utils.clip_grad_norm(
        self._model.parameters(), self.config["grad_clip"])
def _backward(self, batch):
    """Loss is encoded in here. Defining a new loss function
    would start by rewriting this function"""
    states, actions, advs, rs, _ = convert_batch(batch)
    values, action_log_probs, entropy = self._evaluate(states, actions)
    pi_err = -advs.dot(action_log_probs.reshape(-1))
    value_err = F.mse_loss(values.reshape(-1), rs)
    self.optimizer.zero_grad()
    overall_err = sum([
        pi_err,
        self.config["vf_loss_coeff"] * value_err,
        self.config["entropy_coeff"] * entropy,
    ])
    overall_err.backward()
    torch.nn.utils.clip_grad_norm_(self._model.parameters(),
                                   self.config["grad_clip"])
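The three variants of _backward above all compute the same actor-critic objective: an advantage-weighted policy-gradient term, a value-function regression term, and an entropy term. The sketch below illustrates that loss outside the evaluator class; the TinyActorCritic network, the synthetic batch, and the coefficient values are assumptions made for the example, not part of the original code.

# Minimal, self-contained sketch of the actor-critic loss above; the network,
# batch shapes, and coefficient values are illustrative assumptions.
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyActorCritic(nn.Module):
    def __init__(self, obs_dim=4, num_actions=2):
        super().__init__()
        self.body = nn.Linear(obs_dim, 32)
        self.pi = nn.Linear(32, num_actions)   # policy logits
        self.vf = nn.Linear(32, 1)             # state-value head

    def forward(self, states):
        h = torch.tanh(self.body(states))
        return self.pi(h), self.vf(h).reshape(-1)


model = TinyActorCritic()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Synthetic batch: observations, sampled actions, advantages, value targets.
states = torch.randn(8, 4)
actions = torch.randint(0, 2, (8,))
advs = torch.randn(8)
rs = torch.randn(8)

logits, values = model(states)
dist = torch.distributions.Categorical(logits=logits)
action_log_probs = dist.log_prob(actions)
entropy = dist.entropy().sum()

pi_err = -advs.dot(action_log_probs)   # policy-gradient term
value_err = F.mse_loss(values, rs)     # value-function term
# A negative entropy_coeff makes the added entropy term a bonus, matching
# the "- entropy * 0.01" sign convention of the first snippet.
vf_loss_coeff, entropy_coeff, grad_clip = 0.5, -0.01, 40.0

optimizer.zero_grad()
overall_err = pi_err + vf_loss_coeff * value_err + entropy_coeff * entropy
overall_err.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
optimizer.step()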