def forward(self, ob, state_in=None):
    """Dummy recurrent forward pass: returns random action logits and
    increments the recurrent state (used as a test fixture)."""
    if isinstance(ob, PackedSequence):
        ob = ob.data
    logits = np.random.rand(ob.shape[0], 2)
    if state_in is None:
        state_in = torch.from_numpy(np.zeros(10))
    return CatDist(torch.from_numpy(logits)), state_in + 1
def loss(self, batch):
    """Loss function."""
    dist = self.pi(batch['obs']).dist
    q1 = self.qf1(batch['obs'], batch['action']).value
    q2 = self.qf2(batch['obs'], batch['action']).value

    # alpha loss
    if self.automatic_entropy_tuning:
        ent_error = dist.entropy() - self.target_entropy
        alpha_loss = self.log_alpha * ent_error.detach().mean()
        self.opt_alpha.zero_grad()
        alpha_loss.backward()
        self.opt_alpha.step()
        alpha = self.log_alpha.exp()
    else:
        alpha = self.alpha
        alpha_loss = 0

    # qf loss
    with torch.no_grad():
        next_dist = self.pi(batch['next_obs']).dist
        q1_next = self.target_qf1(batch['next_obs']).qvals
        q2_next = self.target_qf2(batch['next_obs']).qvals
        qmin = torch.min(q1_next, q2_next)
        # explicitly compute the expectation over next actions
        qnext = (torch.sum(qmin * next_dist.probs, dim=1)
                 + alpha * next_dist.entropy())
        qtarg = batch['reward'] + (1.0 - batch['done']) * self.gamma * qnext
        assert qtarg.shape == q1.shape
        assert qtarg.shape == q2.shape
    qf1_loss = self.mse_loss(q1, qtarg)
    qf2_loss = self.mse_loss(q2, qtarg)

    # pi loss
    pi_loss = None
    if self.t % self.policy_update_period == 0:
        with torch.no_grad():
            q1_pi = self.qf1(batch['obs']).qvals
            q2_pi = self.qf2(batch['obs']).qvals
            min_q_pi = torch.min(q1_pi, q2_pi)
        assert min_q_pi.shape == dist.logits.shape
        # Scaling the policy logits by alpha makes the KL minimizer the
        # Boltzmann policy softmax(min_q_pi / alpha).
        target_dist = CatDist(logits=min_q_pi)
        pi_dist = CatDist(logits=alpha * dist.logits)
        pi_loss = pi_dist.kl(target_dist).mean()

        # log pi loss about as frequently as other losses
        if self.t % self.log_period < self.policy_update_period:
            logger.add_scalar('loss/pi', pi_loss, self.t, time.time())

    if self.t % self.log_period < self.update_period:
        if self.automatic_entropy_tuning:
            logger.add_scalar('alg/log_alpha',
                              self.log_alpha.detach().cpu().numpy(),
                              self.t, time.time())
            scalars = {
                "target": self.target_entropy,
                "entropy": dist.entropy().mean().detach().cpu().numpy().item()
            }
            logger.add_scalars('alg/entropy', scalars, self.t, time.time())
        else:
            logger.add_scalar(
                'alg/entropy',
                dist.entropy().mean().detach().cpu().numpy().item(),
                self.t, time.time())
        logger.add_scalar('loss/qf1', qf1_loss, self.t, time.time())
        logger.add_scalar('loss/qf2', qf2_loss, self.t, time.time())
        logger.add_scalar('alg/qf1', q1.mean().detach().cpu().numpy(),
                          self.t, time.time())
        logger.add_scalar('alg/qf2', q2.mean().detach().cpu().numpy(),
                          self.t, time.time())
    return pi_loss, qf1_loss, qf2_loss
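# --- Hedged sketch (not part of the original class) ---
# The pi loss above fits the policy to a Boltzmann distribution over the
# minimum of the two Q-heads. The standalone sketch below reproduces that
# objective with torch.distributions.Categorical, assuming CatDist exposes
# the same logits/KL semantics; shapes and values are illustrative only.
import torch
from torch.distributions import Categorical, kl_divergence


def categorical_pi_loss_sketch(pi_logits, min_q, alpha):
    """KL(softmax(alpha * pi_logits) || softmax(min_q)).

    At the minimum, softmax(pi_logits) equals the Boltzmann policy
    softmax(min_q / alpha), i.e. a temperature-alpha greedy policy.
    """
    pi_dist = Categorical(logits=alpha * pi_logits)
    target_dist = Categorical(logits=min_q.detach())
    return kl_divergence(pi_dist, target_dist).mean()


# Illustrative call: batch of 32 states, 2 discrete actions.
example_loss = categorical_pi_loss_sketch(
    torch.randn(32, 2), torch.randn(32, 2), alpha=torch.tensor(0.2))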
def loss(self, batch):
    """Loss function."""
    pi_out = self.pi(batch['obs'], reparameterization_trick=self.rsample)
    if self.discrete:
        new_ac = pi_out.action
        ent = pi_out.dist.entropy()
    else:
        assert isinstance(pi_out.dist, TanhNormal), (
            "It is strongly encouraged that you use a TanhNormal "
            "action distribution for continuous action spaces.")
        if self.rsample:
            new_ac, new_pth_ac = pi_out.dist.rsample(
                return_pretanh_value=True)
        else:
            new_ac, new_pth_ac = pi_out.dist.sample(
                return_pretanh_value=True)
        logp = pi_out.dist.log_prob(new_ac, new_pth_ac)
    q1 = self.qf1(batch['obs'], batch['action']).value
    q2 = self.qf2(batch['obs'], batch['action']).value
    v = self.vf(batch['obs']).value

    # alpha loss
    if self.automatic_entropy_tuning:
        if self.discrete:
            ent_error = -ent + self.target_entropy
        else:
            ent_error = logp + self.target_entropy
        alpha_loss = -(self.log_alpha * ent_error.detach()).mean()
        self.opt_alpha.zero_grad()
        alpha_loss.backward()
        self.opt_alpha.step()
        alpha = self.log_alpha.exp()
    else:
        alpha = 1
        alpha_loss = 0

    # qf loss
    vtarg = self.target_vf(batch['next_obs']).value
    qtarg = self.reward_scale * batch['reward'].float() + (
        (1.0 - batch['done']) * self.gamma * vtarg)
    assert qtarg.shape == q1.shape
    assert qtarg.shape == q2.shape
    qf1_loss = self.qf_criterion(q1, qtarg.detach())
    qf2_loss = self.qf_criterion(q2, qtarg.detach())

    # vf loss: soft value target V(s) = E_a[min Q(s, a) - alpha * log pi(a|s)]
    q1_outs = self.qf1(batch['obs'], new_ac)
    q1_new = q1_outs.value
    q2_new = self.qf2(batch['obs'], new_ac).value
    q = torch.min(q1_new, q2_new)
    if self.discrete:
        vtarg = q + alpha * ent
    else:
        vtarg = q - alpha * logp
    assert v.shape == vtarg.shape
    vf_loss = self.vf_criterion(v, vtarg.detach())

    # pi loss
    pi_loss = None
    if self.t % self.policy_update_period == 0:
        if self.discrete:
            target_dist = CatDist(logits=q1_outs.qvals.detach())
            pi_dist = CatDist(logits=alpha * pi_out.dist.logits)
            pi_loss = pi_dist.kl(target_dist).mean()
        else:
            if self.rsample:
                # pathwise gradient through the reparameterized sample
                assert q.shape == logp.shape
                pi_loss = (alpha * logp - q1_new).mean()
            else:
                # likelihood-ratio gradient with V(s) as a baseline
                pi_targ = q1_new - v
                assert pi_targ.shape == logp.shape
                pi_loss = (logp * (alpha * logp - pi_targ).detach()).mean()
            pi_loss += self.policy_mean_reg_weight * (
                pi_out.dist.normal.mean**2).mean()

        # log pi loss about as frequently as other losses
        if self.t % self.log_period < self.policy_update_period:
            logger.add_scalar('loss/pi', pi_loss, self.t, time.time())

    if self.t % self.log_period < self.update_period:
        if self.automatic_entropy_tuning:
            logger.add_scalar('ent/log_alpha',
                              self.log_alpha.detach().cpu().numpy(),
                              self.t, time.time())
            if self.discrete:
                scalars = {"target": self.target_entropy,
                           "entropy": ent.mean().detach().cpu().numpy().item()}
            else:
                scalars = {"target": self.target_entropy,
                           "entropy": -torch.mean(
                               logp.detach()).cpu().numpy().item()}
            logger.add_scalars('ent/entropy', scalars, self.t, time.time())
        else:
            if self.discrete:
                logger.add_scalar(
                    'ent/entropy',
                    ent.mean().detach().cpu().numpy().item(),
                    self.t, time.time())
            else:
                logger.add_scalar(
                    'ent/entropy',
                    -torch.mean(logp.detach()).cpu().numpy().item(),
                    self.t, time.time())
        logger.add_scalar('loss/qf1', qf1_loss, self.t, time.time())
        logger.add_scalar('loss/qf2', qf2_loss, self.t, time.time())
        logger.add_scalar('loss/vf', vf_loss, self.t, time.time())
    return pi_loss, qf1_loss, qf2_loss, vf_loss
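# --- Hedged sketch (not part of the original class) ---
# The continuous branch above depends on TanhNormal returning the pre-tanh
# sample so log_prob can apply the change-of-variables correction without
# inverting tanh. A minimal stand-in under that assumed interface (the
# class name and the 1e-6 epsilon are illustrative, not the library's):
import torch
from torch.distributions import Normal


class TanhNormalSketch:
    """Squashed Gaussian: a = tanh(z) with z ~ Normal(mean, std)."""

    def __init__(self, mean, std):
        self.normal = Normal(mean, std)

    def rsample(self, return_pretanh_value=False):
        z = self.normal.rsample()  # reparameterized, so gradients flow
        a = torch.tanh(z)
        return (a, z) if return_pretanh_value else a

    def sample(self, return_pretanh_value=False):
        z = self.normal.sample()  # no gradient through the sample
        a = torch.tanh(z)
        return (a, z) if return_pretanh_value else a

    def log_prob(self, action, pretanh_value):
        # log pi(a) = log N(z) - sum_i log(1 - tanh(z_i)^2); the epsilon
        # guards against log(0) where tanh saturates.
        correction = torch.log(1.0 - action.pow(2) + 1e-6)
        return (self.normal.log_prob(pretanh_value) - correction).sum(-1)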
def forward(self, ob):
    """Dummy actor-critic forward pass: returns random action logits and
    random values (used as a test fixture)."""
    logits = np.random.rand(ob.shape[0], 2)
    v = np.random.rand(ob.shape[0], 1)
    return CatDist(torch.from_numpy(logits)), v