def train(traj,
          pol, targ_pol, qf, targ_qf,
          optim_pol, optim_qf, epoch, batch_size,  # optimization hypers
          tau, gamma,
          log_enable=True,
          ):
    """
    Train function for deep deterministic policy gradient
    with prioritized experience replay.
    """
    pol_losses = []
    qf_losses = []
    if log_enable:
        logger.log("Optimizing...")
    for batch, indices in traj.prioritized_random_batch(batch_size, epoch, return_indices=True):
        qf_bellman_loss = lf.bellman(
            qf, targ_qf, targ_pol, batch, gamma, reduction='none')
        td_loss = torch.sqrt(qf_bellman_loss * 2)
        qf_bellman_loss = torch.mean(qf_bellman_loss)
        optim_qf.zero_grad()
        qf_bellman_loss.backward()
        optim_qf.step()

        pol_loss = lf.ag(pol, qf, batch)
        optim_pol.zero_grad()
        pol_loss.backward()
        optim_pol.step()

        for p, targ_p in zip(pol.parameters(), targ_pol.parameters()):
            targ_p.detach().copy_((1 - tau) * targ_p.detach() + tau * p.detach())
        for q, targ_q in zip(qf.parameters(), targ_qf.parameters()):
            targ_q.detach().copy_((1 - tau) * targ_q.detach() + tau * q.detach())

        qf_losses.append(qf_bellman_loss.detach().cpu().numpy())
        pol_losses.append(pol_loss.detach().cpu().numpy())

        traj = tf.update_pris(traj, td_loss, indices)

    if log_enable:
        logger.log("Optimization finished!")

    return {'PolLoss': pol_losses, 'QfLoss': qf_losses}
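# The loop above mixes three updates per batch: a critic step on the Bellman loss, an actor
# step via lf.ag, and a soft "Polyak" update of both target networks. Below is a minimal,
# self-contained sketch of just the soft update rule used in the loops,
# targ_param <- (1 - tau) * targ_param + tau * param. The Linear networks and the tau value
# are illustrative assumptions, not part of the library.
import copy

import torch
import torch.nn as nn


def soft_update(net, targ_net, tau):
    # Blend each target parameter a small step toward the online parameter.
    with torch.no_grad():
        for p, targ_p in zip(net.parameters(), targ_net.parameters()):
            targ_p.copy_((1 - tau) * targ_p + tau * p)


net = nn.Linear(4, 2)
targ_net = copy.deepcopy(net)
# ... after an optimizer step on net ...
soft_update(net, targ_net, tau=0.01)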
def compute_pris(data, qf, targ_qf, pol, gamma, continuous=True, deterministic=True,
                 rnn=False, sampling=1, alpha=0.6, epsilon=1e-6):
    """
    Compute prioritization from per-transition TD errors.
    """
    if continuous:
        epis = data.current_epis
        for epi in epis:
            data_map = dict()
            keys = ['obs', 'acs', 'rews', 'next_obs', 'dones']
            for key in keys:
                data_map[key] = torch.tensor(epi[key], device=get_device())
            if rnn:
                qf.reset()
                targ_qf.reset()
                pol.reset()
                keys = ['obs', 'acs', 'next_obs']
                for key in keys:
                    data_map[key] = data_map[key].unsqueeze(1)
            with torch.no_grad():
                bellman_loss = lf.bellman(
                    qf, targ_qf, pol, data_map, gamma, continuous, deterministic,
                    sampling, reduction='none')
                td_loss = torch.sqrt(bellman_loss * 2)
                pris = (torch.abs(td_loss) + epsilon)**alpha
            epi['pris'] = pris.cpu().numpy()
        return data
    else:
        raise NotImplementedError(
            "Only Q function with continuous action space is supported now.")
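# Sketch of the proportional prioritization computed above: each transition's priority is
# (|TD error| + epsilon) ** alpha, and the replay buffer then samples transitions with
# probability proportional to that priority. The toy TD errors below are made-up values.
import torch

td_errors = torch.tensor([0.5, 0.01, 2.0, 0.0])
alpha, epsilon = 0.6, 1e-6

pris = (torch.abs(td_errors) + epsilon) ** alpha  # per-transition priorities
probs = pris / pris.sum()                         # sampling distribution over transitions
print(probs)  # larger |TD error| -> sampled more often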
def train(traj,
          pol, targ_pol, qf, targ_qf,
          optim_pol, optim_qf, epoch, batch_size,  # optimization hypers
          tau, gamma  # advantage estimation
          ):
    """
    Train function for deep deterministic policy gradient.

    Parameters
    ----------
    traj : Traj
        Off policy trajectory.
    pol : Pol
        Policy.
    targ_pol : Pol
        Target Policy.
    qf : SAVfunction
        Q function.
    targ_qf : SAVfunction
        Target Q function.
    optim_pol : torch.optim.Optimizer
        Optimizer for Policy.
    optim_qf : torch.optim.Optimizer
        Optimizer for Q function.
    epoch : int
        Number of iterations.
    batch_size : int
        Batch size.
    tau : float
        Target updating rate.
    gamma : float
        Discounting rate.

    Returns
    -------
    result_dict : dict
        Dictionary which contains losses information.
    """
    pol_losses = []
    qf_losses = []
    logger.log("Optimizing...")
    for batch in traj.random_batch(batch_size, epoch):
        qf_bellman_loss = lf.bellman(qf, targ_qf, targ_pol, batch, gamma)
        optim_qf.zero_grad()
        qf_bellman_loss.backward()
        optim_qf.step()

        pol_loss = lf.ag(pol, qf, batch, no_noise=True)
        optim_pol.zero_grad()
        pol_loss.backward()
        optim_pol.step()

        for p, targ_p in zip(pol.parameters(), targ_pol.parameters()):
            targ_p.detach().copy_((1 - tau) * targ_p.detach() + tau * p.detach())
        for q, targ_q in zip(qf.parameters(), targ_qf.parameters()):
            targ_q.detach().copy_((1 - tau) * targ_q.detach() + tau * q.detach())

        qf_losses.append(qf_bellman_loss.detach().cpu().numpy())
        pol_losses.append(pol_loss.detach().cpu().numpy())

    logger.log("Optimization finished!")

    return {'PolLoss': pol_losses, 'QfLoss': qf_losses}
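# A standalone sketch of the critic target that the Bellman loss above is built from
# (assuming lf.bellman follows the standard DDPG form; the toy networks, shapes, and the
# 0.5 factor are illustrative assumptions):
#   y = r + gamma * (1 - done) * targ_qf(s', targ_pol(s'))
#   loss = 0.5 * mean((qf(s, a) - y) ** 2)
import torch
import torch.nn as nn

obs_dim, ac_dim, batch = 3, 2, 5
qf = nn.Linear(obs_dim + ac_dim, 1)
targ_qf = nn.Linear(obs_dim + ac_dim, 1)
targ_pol = nn.Sequential(nn.Linear(obs_dim, ac_dim), nn.Tanh())

obs = torch.randn(batch, obs_dim)
acs = torch.randn(batch, ac_dim)
rews = torch.randn(batch)
next_obs = torch.randn(batch, obs_dim)
dones = torch.zeros(batch)
gamma = 0.99

with torch.no_grad():
    next_acs = targ_pol(next_obs)
    targ_q = targ_qf(torch.cat([next_obs, next_acs], dim=-1)).squeeze(-1)
    y = rews + gamma * (1 - dones) * targ_q  # TD target from the target networks

q = qf(torch.cat([obs, acs], dim=-1)).squeeze(-1)
qf_bellman_loss = 0.5 * (q - y).pow(2).mean()  # 0.5 * squared TD error, averaged over the batch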
def train(traj,
          pol, targ_pol, qf, targ_qf,
          optim_pol, optim_qf, epoch, batch_size,  # optimization hypers
          tau, gamma,  # advantage estimation
          sampling,
          log_enable=True,
          ):
    """
    Train function for deep deterministic policy gradient.

    Parameters
    ----------
    traj : Traj
        Off policy trajectory.
    pol : Pol
        Policy.
    targ_pol : Pol
        Target Policy.
    qf : SAVfunction
        Q function.
    targ_qf : SAVfunction
        Target Q function.
    optim_pol : torch.optim.Optimizer
        Optimizer for Policy.
    optim_qf : torch.optim.Optimizer
        Optimizer for Q function.
    epoch : int
        Number of iterations.
    batch_size : int
        Batch size.
    tau : float
        Target updating rate.
    gamma : float
        Discounting rate.
    sampling : int
        Number of samples used in calculating the expectation.
    log_enable : bool
        If True, enable logging.

    Returns
    -------
    result_dict : dict
        Dictionary which contains losses information.
    """
    pol_losses = []
    qf_losses = []
    if log_enable:
        logger.log("Optimizing...")
    for batch in traj.iterate(batch_size, epoch):
        qf_bellman_loss = lf.bellman(
            qf, targ_qf, targ_pol, batch, gamma, sampling=sampling)
        optim_qf.zero_grad()
        qf_bellman_loss.backward()
        optim_qf.step()

        pol_loss = lf.ag(pol, qf, batch, sampling)
        optim_pol.zero_grad()
        pol_loss.backward()
        optim_pol.step()

        # Update the target networks in separate loops so that differing parameter
        # counts between pol and qf do not truncate the zip and skip updates.
        for p, targ_p in zip(pol.parameters(), targ_pol.parameters()):
            targ_p.detach().copy_((1 - tau) * targ_p.detach() + tau * p.detach())
        for q, targ_q in zip(qf.parameters(), targ_qf.parameters()):
            targ_q.detach().copy_((1 - tau) * targ_q.detach() + tau * q.detach())

        qf_losses.append(qf_bellman_loss.detach().cpu().numpy())
        pol_losses.append(pol_loss.detach().cpu().numpy())

    if log_enable:
        logger.log("Optimization finished!")

    return dict(
        PolLoss=pol_losses,
        QfLoss=qf_losses,
    )
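# Rough sketch of the actor objective that lf.ag presumably estimates here: maximize the
# expected Q value of the policy's own actions, approximated with `sampling` reparameterized
# action samples per state, i.e. pol_loss = -mean(Q(s, a)) with a ~ pol(s). The Gaussian
# policy, network shapes, and sample count are illustrative assumptions.
import torch
import torch.nn as nn

obs_dim, ac_dim, batch, sampling = 3, 2, 5, 4
qf = nn.Linear(obs_dim + ac_dim, 1)
pol_mean = nn.Sequential(nn.Linear(obs_dim, ac_dim), nn.Tanh())
log_std = torch.zeros(ac_dim, requires_grad=True)

obs = torch.randn(batch, obs_dim)
mean = pol_mean(obs)
std = log_std.exp().expand_as(mean)
# Reparameterized samples keep the actions differentiable w.r.t. the policy parameters.
acs = mean.unsqueeze(0) + std.unsqueeze(0) * torch.randn(sampling, batch, ac_dim)
q = qf(torch.cat([obs.expand(sampling, batch, obs_dim), acs], dim=-1))
pol_loss = -q.mean()  # ascend E[Q(s, pi(s))] by descending its negation
pol_loss.backward()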
def compute_pris(data, qf, targ_qf, pol, gamma, continuous=True, deterministic=True,
                 rnn=False, sampling=1, alpha=0.6, epsilon=1e-6):
    """
    Compute prioritization.

    Parameters
    ----------
    data : Traj or epis (dict of ndarray)
    qf : SAVfunction
    targ_qf : SAVfunction
    pol : Pol
    gamma : float
    continuous : bool
    deterministic : bool
    rnn : bool
    sampling : int
    alpha : float
    epsilon : float

    Returns
    -------
    data : Traj or epis (dict of ndarray)
        Corresponding to input.
    """
    if continuous:
        if isinstance(data, Traj):
            epis = data.current_epis
        else:
            epis = data
        for epi in epis:
            data_map = dict()
            keys = ['obs', 'acs', 'rews', 'next_obs', 'dones']
            for key in keys:
                data_map[key] = torch.tensor(epi[key], device=get_device())
            if rnn:
                qf.reset()
                targ_qf.reset()
                pol.reset()
                keys = ['obs', 'acs', 'next_obs']
                for key in keys:
                    data_map[key] = data_map[key].unsqueeze(1)
            with torch.no_grad():
                bellman_loss = lf.bellman(
                    qf, targ_qf, pol, data_map, gamma, continuous, deterministic,
                    sampling, reduction='none')
                td_loss = torch.sqrt(bellman_loss * 2)
                pris = (torch.abs(td_loss) + epsilon)**alpha
            epi['pris'] = pris.cpu().numpy()
        return data
    else:
        raise NotImplementedError(
            "Only Q function with continuous action space is supported now.")