import torch

from machina import loss_functional as lf
from machina import logger


def train(traj,
          pol, qfs, targ_qfs, log_alpha,
          optim_pol, optim_qfs, optim_alpha,
          epoch, batch_size,  # optimization hypers
          tau, gamma, sampling,
          reparam=True,
          log_enable=True,
          max_grad_norm=0.5,
          ):
    """
    Train function for soft actor critic.

    Parameters
    ----------
    traj : Traj
        Off policy trajectory.
    pol : Pol
        Policy.
    qfs : list of SAVfunction
        Q functions.
    targ_qfs : list of SAVfunction
        Target Q functions.
    log_alpha : torch.Tensor
        Temperature parameter of entropy.
    optim_pol : torch.optim.Optimizer
        Optimizer for Policy.
    optim_qfs : list of torch.optim.Optimizer
        Optimizers for Q functions.
    optim_alpha : torch.optim.Optimizer
        Optimizer for alpha.
    epoch : int
        Number of iterations.
    batch_size : int
        Size of each batch.
    tau : float
        Target updating rate.
    gamma : float
        Discounting rate.
    sampling : int
        Number of samples drawn when estimating expectations.
    reparam : bool
        If True, use the reparameterization trick.
    log_enable : bool
        If True, enable logging.
    max_grad_norm : float
        Maximum gradient norm.

    Returns
    -------
    result_dict : dict
        Dictionary which contains losses information.
    """
    pol_losses = []
    _qf_losses = []
    alpha_losses = []
    if log_enable:
        logger.log("Optimizing...")
    for batch in traj.random_batch(batch_size, epoch):
        pol_loss, qf_losses, alpha_loss = lf.sac(
            pol, qfs, targ_qfs, log_alpha, batch, gamma, sampling, reparam)

        optim_pol.zero_grad()
        pol_loss.backward()
        torch.nn.utils.clip_grad_norm_(pol.parameters(), max_grad_norm)
        optim_pol.step()

        for qf, optim_qf, qf_loss in zip(qfs, optim_qfs, qf_losses):
            optim_qf.zero_grad()
            qf_loss.backward()
            torch.nn.utils.clip_grad_norm_(qf.parameters(), max_grad_norm)
            optim_qf.step()

        optim_alpha.zero_grad()
        alpha_loss.backward()
        optim_alpha.step()

        # Polyak averaging: softly move the target networks toward the
        # online Q networks at rate tau.
        for qf, targ_qf in zip(qfs, targ_qfs):
            for q, targ_q in zip(qf.parameters(), targ_qf.parameters()):
                targ_q.detach().copy_(
                    (1 - tau) * targ_q.detach() + tau * q.detach())

        pol_losses.append(pol_loss.detach().cpu().numpy())
        _qf_losses.append(
            (sum(qf_losses) / len(qf_losses)).detach().cpu().numpy())
        alpha_losses.append(alpha_loss.detach().cpu().numpy())

    if log_enable:
        logger.log("Optimization finished!")

    return dict(
        PolLoss=pol_losses,
        QfLoss=_qf_losses,
        AlphaLoss=alpha_losses
    )
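# Example usage (a hypothetical sketch, not from the original source):
# `pol`, `qfs`, `targ_qfs`, and `traj` are assumed to be machina-style
# objects constructed elsewhere; the optimizer choices and hyperparameter
# values below are illustrative only.
#
#   log_alpha = torch.zeros((), requires_grad=True)
#   optim_pol = torch.optim.Adam(pol.parameters(), lr=1e-4)
#   optim_qfs = [torch.optim.Adam(qf.parameters(), lr=3e-4) for qf in qfs]
#   optim_alpha = torch.optim.Adam([log_alpha], lr=1e-4)
#   result_dict = train(traj, pol, qfs, targ_qfs, log_alpha,
#                       optim_pol, optim_qfs, optim_alpha,
#                       epoch=100, batch_size=256,
#                       tau=5e-3, gamma=0.99, sampling=1)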
def train(traj,
          pol, qfs, targ_qfs, log_alpha,
          optim_pol, optim_qfs, optim_alpha,
          epoch, batch_size,  # optimization hypers
          tau, gamma, sampling,
          discrim, num_skill,
          reparam=True):
    """
    Train function for soft actor critic with DIAYN-style skill rewards.

    Parameters
    ----------
    traj : Traj
        Off policy trajectory.
    pol : Pol
        Policy.
    qfs : list of SAVfunction
        Q functions.
    targ_qfs : list of SAVfunction
        Target Q functions.
    log_alpha : torch.Tensor
        Temperature parameter of entropy.
    optim_pol : torch.optim.Optimizer
        Optimizer for Policy.
    optim_qfs : list of torch.optim.Optimizer
        Optimizers for Q functions.
    optim_alpha : torch.optim.Optimizer
        Optimizer for alpha.
    epoch : int
        Number of iterations.
    batch_size : int
        Size of each batch.
    tau : float
        Target updating rate.
    gamma : float
        Discounting rate.
    sampling : int
        Number of samples drawn when estimating expectations.
    discrim : SVfunction
        Discriminator.
    num_skill : int
        Number of skills.
    reparam : bool
        If True, use the reparameterization trick.

    Returns
    -------
    result_dict : dict
        Dictionary which contains losses information.
    """
    pol_losses = []
    _qf_losses = []
    alpha_losses = []
    logger.log("Optimizing...")
    for batch in traj.random_batch(batch_size, epoch):
        # Replace the environment rewards with discriminator-based
        # pseudo-rewards before computing the SAC losses.
        with torch.no_grad():
            rews, info = calc_rewards(batch['obs'], num_skill, discrim)
            batch['rews'] = rews

        pol_loss, qf_losses, alpha_loss = lf.sac(
            pol, qfs, targ_qfs, log_alpha, batch, gamma, sampling, reparam)

        optim_pol.zero_grad()
        pol_loss.backward()
        optim_pol.step()

        for optim_qf, qf_loss in zip(optim_qfs, qf_losses):
            optim_qf.zero_grad()
            qf_loss.backward()
            optim_qf.step()

        optim_alpha.zero_grad()
        alpha_loss.backward()
        optim_alpha.step()

        # Polyak averaging: softly move the target networks toward the
        # online Q networks at rate tau.
        for qf, targ_qf in zip(qfs, targ_qfs):
            for q, targ_q in zip(qf.parameters(), targ_qf.parameters()):
                targ_q.detach().copy_(
                    (1 - tau) * targ_q.detach() + tau * q.detach())

        pol_losses.append(pol_loss.detach().cpu().numpy())
        _qf_losses.append(
            (sum(qf_losses) / len(qf_losses)).detach().cpu().numpy())
        alpha_losses.append(alpha_loss.detach().cpu().numpy())

    logger.log("Optimization finished!")

    return dict(PolLoss=pol_losses,
                QfLoss=_qf_losses,
                AlphaLoss=alpha_losses)
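# `calc_rewards` is called above but not defined in this section. Below is a
# minimal sketch of the DIAYN pseudo-reward r = log q(z|s) - log p(z), under
# the assumptions that `obs` concatenates the environment observation with a
# one-hot skill vector of length `num_skill`, that `discrim` is a callable
# returning unnormalized logits over skills, and that the skill prior p(z) is
# uniform. The names and shapes here are assumptions, not the original
# implementation.

import math


def calc_rewards(obs, num_skill, discrim):
    """Sketch of the DIAYN pseudo-reward (assumed layout of `obs`)."""
    state = obs[:, :-num_skill]   # strip the skill one-hot off the observation
    skill = obs[:, -num_skill:]   # one-hot encoding of the active skill
    logits = discrim(state)       # unnormalized skill logits from q(z|s)
    log_q_z = torch.log_softmax(logits, dim=1)
    log_q = (log_q_z * skill).sum(dim=1)   # log q(z|s) for the skill in use
    log_p_z = math.log(1.0 / num_skill)    # log p(z) under a uniform prior
    rews = log_q - log_p_z
    info = dict(mean_log_q=log_q.mean().item())
    return rews, info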