def evaluate(self): """Evaluate.""" eval_env = VecEpsilonGreedy(VecFrameStack(self.env, self.frame_stack), self.eval_eps) self.qf.eval() misc.set_env_to_eval_mode(eval_env) # Eval policy os.makedirs(os.path.join(self.logdir, 'eval'), exist_ok=True) outfile = os.path.join(self.logdir, 'eval', self.ckptr.format.format(self.t) + '.json') stats = rl_evaluate(eval_env, self.qf, self.eval_num_episodes, outfile, self.device) logger.add_scalar('eval/mean_episode_reward', stats['mean_reward'], self.t, time.time()) logger.add_scalar('eval/mean_episode_length', stats['mean_length'], self.t, time.time()) # Record policy os.makedirs(os.path.join(self.logdir, 'video'), exist_ok=True) outfile = os.path.join(self.logdir, 'video', self.ckptr.format.format(self.t) + '.mp4') rl_record(eval_env, self.qf, self.record_num_episodes, outfile, self.device) self.qf.train() misc.set_env_to_train_mode(self.env) self.data_manager.manual_reset()
def __call__(self, ob, state_in=None):
    """Produce decision from model."""
    if self.t < self.policy_training_start:
        outs = self.pi(ob, state_in, deterministic=True)
    else:
        outs = self.pi(ob, state_in)

    def _res_norm(ac):
        return ac.abs().sum(dim=1).mean()

    residual_norm = nest.map_structure(_res_norm, outs.action)
    if isinstance(residual_norm, torch.Tensor):
        logger.add_scalar('actor/l1_residual_norm', residual_norm, self.t,
                          time.time())
        self.t += outs.action.shape[0]
    else:
        self.t += nest.flatten(outs.action)[0].shape[0]
        for k, v in residual_norm.items():
            logger.add_scalar(f'actor/{k}_residual_norm', v, self.t,
                              time.time())

    data = {'action': outs.action,
            'value': self.vf(ob).value,
            'logp': outs.dist.log_prob(outs.action),
            'dist': outs.dist.to_tensors()}
    if outs.state_out:
        data['state'] = outs.state_out
    return data

def evaluate(self): """Evaluate model.""" self.pi.eval() misc.set_env_to_eval_mode(self.env) # Eval policy os.makedirs(os.path.join(self.logdir, 'eval'), exist_ok=True) outfile = os.path.join(self.logdir, 'eval', self.ckptr.format.format(self.t) + '.json') stats = rl_evaluate(self.env, self.pi, self.eval_num_episodes, outfile, self.device) logger.add_scalar('eval/mean_episode_reward', stats['mean_reward'], self.t, time.time()) logger.add_scalar('eval/mean_episode_length', stats['mean_length'], self.t, time.time()) # Record policy # os.makedirs(os.path.join(self.logdir, 'video'), exist_ok=True) # outfile = os.path.join(self.logdir, 'video', # self.ckptr.format.format(self.t) + '.mp4') # rl_record(self.env, self.pi, self.record_num_episodes, outfile, # self.device) self.pi.train() misc.set_env_to_train_mode(self.env)
def loss(self, batch): """Loss function.""" # compute QFunction loss. with torch.no_grad(): target_action = self.target_pi(batch['next_obs']).action noise = (torch.randn_like(target_action) * self.policy_noise).clamp(-self.policy_noise_clip, self.policy_noise_clip) target_action = (target_action + noise).clamp(-1., 1.) target_q1 = self.target_qf1(batch['next_obs'], target_action).value target_q2 = self.target_qf2(batch['next_obs'], target_action).value target_q = torch.min(target_q1, target_q2) qtarg = self.reward_scale * batch['reward'].float() + ( (1.0 - batch['done']) * self.gamma * target_q) q1 = self.qf1(batch['obs'], batch['action']).value q2 = self.qf2(batch['obs'], batch['action']).value assert qtarg.shape == q1.shape assert qtarg.shape == q2.shape qf_loss = self.qf_criterion(q1, qtarg) + self.qf_criterion(q2, qtarg) # compute policy loss if self.t % self.policy_update_period == 0: action = self.pi(batch['obs'], deterministic=True).action q = self.qf1(batch['obs'], action).value pi_loss = -q.mean() else: pi_loss = torch.zeros_like(qf_loss) # log losses if self.t % self.log_period < self.update_period: logger.add_scalar('loss/qf', qf_loss, self.t, time.time()) if self.t % self.policy_update_period == 0: logger.add_scalar('loss/pi', pi_loss, self.t, time.time()) return pi_loss, qf_loss
def step(self):
    # Get batch.
    if self._diter is None:
        self._diter = iter(self.dtrain)
    try:
        batch = next(self._diter)
    except StopIteration:
        self.epochs += 1
        self._diter = None
        return self.epochs
    batch = nest.map_structure(lambda x: x.to(self.device), batch)

    # compute loss
    ob, ac = batch
    self.model.train()
    loss = -self.model(ob).log_prob(ac).mean()
    logger.add_scalar('train/loss', loss.detach().cpu().numpy(), self.t,
                      time.time())

    # update model
    self.opt.zero_grad()
    loss.backward()
    self.opt.step()

    # increment step
    self.t += min(len(self.data) - (self.t % len(self.data)),
                  self.batch_size)
    return self.epochs

def log_stats():
    from dl import logger
    cpu_util, mem_util, gpus = get_stats()
    timestamp = time.time()
    logger.add_scalar('hardware/cpu_util', cpu_util, walltime=timestamp)
    logger.add_scalar('hardware/mem_util', mem_util, walltime=timestamp)
    for gpu in gpus:
        logger.add_scalar(f'hardware/gpu{gpu.id}/util', gpu.util,
                          walltime=timestamp)
        logger.add_scalar(f'hardware/gpu{gpu.id}/mem_util', gpu.memutil,
                          walltime=timestamp)

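# Sketch (not from the original file): one way `get_stats` could be
# implemented, assuming the psutil and GPUtil packages. The namedtuple fields
# mirror the attributes used by `log_stats` above (`id`, `util`, `memutil`);
# the implementation itself is an assumption for illustration.
from collections import namedtuple

import GPUtil
import psutil

GPUStat = namedtuple('GPUStat', ['id', 'util', 'memutil'])


def get_stats():
    """Return (cpu_util, mem_util, gpus) for hardware logging."""
    cpu_util = psutil.cpu_percent()
    mem_util = psutil.virtual_memory().percent
    gpus = [GPUStat(g.id, g.load, g.memoryUtil) for g in GPUtil.getGPUs()]
    return cpu_util, mem_util, gpus
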
def evaluate(self): """Evaluate model.""" self.model.eval() accuracy = [] with torch.no_grad(): for batch in self.dtest: x, y = nest.map_structure(lambda x: x.to(self.device), batch) y_hat = self.model(x).argmax(-1) accuracy.append((y_hat == y).float().mean().cpu().numpy()) logger.add_scalar(f'test_accuracy', np.mean(accuracy), self.epochs, time.time())
def step(self, action): """Step.""" obs, rews, dones, infos = self.venv.step(action) if not self._eval: self.t += np.sum(np.logical_not(self._dones)) for i, d in enumerate(self._dones): # handle synced resets if not d: self.lens[i] += 1 self.rews[i] += rews[i] else: assert dones[i] for i, done in enumerate(dones): if done and not self._dones[i]: if not self._eval: logger.add_scalar('env/episode_length', self.lens[i], self.t, time.time()) logger.add_scalar('env/episode_reward', self.rews[i], self.t, time.time()) self.lens[i] = 0 self.rews[i] = 0. # log unwrapped episode stats if they exist if 'episode_info' in infos[0]: for i, info in enumerate(infos): epinfo = info['episode_info'] if epinfo['done'] and not self._eval and not self._dones[i]: logger.add_scalar('env/unwrapped_episode_length', epinfo['length'], self.t, time.time()) logger.add_scalar('env/unwrapped_episode_reward', epinfo['reward'], self.t, time.time()) self._dones = np.logical_or(dones, self._dones) return obs, rews, dones, infos
def loss(self, batch): """Loss.""" q = self.qf(batch['obs'], batch['action']).value with torch.no_grad(): target = self._compute_target(batch['reward'], batch['next_obs'], batch['done']) assert target.shape == q.shape err = self.criterion(target, q) self.buffer.update_priorities(batch['idxes'], err.detach().cpu().numpy() + 1e-6) assert err.shape == batch['weights'].shape err = batch['weights'] * err loss = err.mean() if self.t % self.log_period < self.update_period: logger.add_scalar('alg/maxq', torch.max(q).detach().cpu().numpy(), self.t, time.time()) logger.add_scalar('alg/loss', loss.detach().cpu().numpy(), self.t, time.time()) logger.add_scalar('alg/epsilon', self.eps_schedule.value(self._actor.t), self.t, time.time()) logger.add_scalar('alg/beta', self.beta_schedule.value(self.t), self.t, time.time()) return loss
def evaluate(self): """Evaluate model.""" self.model.eval() nll = 0. with torch.no_grad(): for batch in self.dtrain: ob, ac = nest.map_structure(lambda x: x.to(self.device), batch) nll += -self.model(ob).log_prob(ac).sum() avg_nll = nll / len(self.data) logger.add_scalar('train/NLL', nll, self.epochs, time.time()) logger.add_scalar('train/AVG_NLL', avg_nll, self.epochs, time.time())
def __call__(self, obs): """Act.""" self.t += nest.flatten(obs)[0].shape[0] if self.should_take_zero_action(): if self.zero_action is None: with torch.no_grad(): self.zero_action = nest.map_structure( torch.zeros_like, self.pi(obs).action) return {'action': self.zero_action} else: ac = self.pi(obs).action with torch.no_grad(): ac_norm = ac.abs().mean().cpu().numpy() logger.add_scalar('alg/residual_norm', ac_norm, self.t, time.time()) return {'action': self.pi(obs).action}
def step(self): """Step alpha zero.""" self.pi.train() self.t += self.data_manager.play_game(self.n_sims) # fill replay buffer if needed while not self.buffer.full(): self.t += self.data_manager.play_game(self.n_sims) for _ in range(self.batches_per_game): batch = self.data_manager.sample(self.batch_size) self.opt.zero_grad() loss = self.loss(batch) loss['total'].backward() self.opt.step() for k, v in loss.items(): logger.add_scalar(f'loss/{k}', v.detach().cpu().numpy(), self.t, time.time()) return self.t
def __call__(self, ob, state_in=None, mask=None):
    """Produce decision from model."""
    if self.t < self.policy_training_start:
        outs = self.pi(ob, state_in, mask, deterministic=True)
        if not torch.allclose(outs.action, torch.zeros_like(outs.action)):
            raise ValueError("Pi should be initialized to output zero "
                             "actions so that an accurate value function "
                             "can be learned for the base policy.")
    else:
        outs = self.pi(ob, state_in, mask)

    residual_norm = torch.mean(torch.sum(torch.abs(outs.action), dim=1))
    logger.add_scalar('actor/l1_residual_norm', residual_norm, self.t,
                      time.time())
    self.t += outs.action.shape[0]

    data = {'action': outs.action,
            'value': outs.value,
            'logp': outs.dist.log_prob(outs.action)}
    if outs.state_out:
        data['state'] = outs.state_out
    return data

def loss(self, batch): """Loss function.""" # compute QFunction loss. with torch.no_grad(): target_action = self.target_pi(batch['next_obs']).action target_q = self.target_qf(batch['next_obs'], target_action).value qtarg = self.reward_scale * batch['reward'].float() + ( (1.0 - batch['done']) * self.gamma * target_q) q = self.qf(batch['obs'], batch['action']).value assert qtarg.shape == q.shape qf_loss = self.qf_criterion(q, qtarg) # compute policy loss action = self.pi(batch['obs'], deterministic=True).action q = self.qf(batch['obs'], action).value pi_loss = -q.mean() # log losses if self.t % self.log_period < self.update_period: logger.add_scalar('loss/qf', qf_loss, self.t, time.time()) logger.add_scalar('loss/pi', pi_loss, self.t, time.time()) return pi_loss, qf_loss
def loss(self, batch): """Compute loss.""" q = self.qf(batch['obs'], batch['action']).value with torch.no_grad(): target = self._compute_target(batch['reward'], batch['next_obs'], batch['done']) assert target.shape == q.shape loss = self.criterion(target, q).mean() if self.t % self.log_period < self.update_period: logger.add_scalar('alg/maxq', torch.max(q).detach().cpu().numpy(), self.t, time.time()) logger.add_scalar('alg/loss', loss.detach().cpu().numpy(), self.t, time.time()) logger.add_scalar('alg/epsilon', self.eps_schedule.value(self._actor.t), self.t, time.time()) return loss
def step(self): """Compute rollout, loss, and update model.""" self.pi.train() # adjust learning rate lr_frac = self.lr_decay_rate**(self.t // self.lr_decay_freq) for g in self.opt.param_groups: g['lr'] = self.pi_lr * lr_frac for g in self.opt_l.param_groups: g['lr'] = self.lambda_lr * lr_frac self.data_manager.rollout() self.t += self.data_manager.rollout_length * self.nenv losses = {} for _ in range(self.epochs_per_rollout): for batch in self.data_manager.sampler(): loss = self.loss(batch) if losses == {}: losses = {k: [] for k in loss} for k, v in loss.items(): losses[k].append(v.detach().cpu().numpy()) if self.t >= max(self.policy_training_start, self.lambda_training_start): self.opt_l.zero_grad() loss['lambda'].backward(retain_graph=True) self.opt_l.step() self.opt.zero_grad() loss['total'].backward() if self.max_grad_norm: nn.utils.clip_grad_norm_(self.pi.parameters(), self.max_grad_norm) self.opt.step() for k, v in losses.items(): logger.add_scalar(f'loss/{k}', np.mean(v), self.t, time.time()) logger.add_scalar('alg/lr_pi', self.opt.param_groups[0]['lr'], self.t, time.time()) logger.add_scalar('alg/lr_lambda', self.opt_l.param_groups[0]['lr'], self.t, time.time()) return self.t
def step(self): """Compute rollout, loss, and update model.""" self.pi.train() self.t += self.data_manager.rollout() losses = {} for _ in range(self.epochs_per_rollout): for batch in self.data_manager.sampler(): self.opt.zero_grad() loss = self.loss(batch) if losses == {}: losses = {k: [] for k in loss} for k, v in loss.items(): losses[k].append(v.detach().cpu().numpy()) loss['total'].backward() if self.max_grad_norm: norm = nn.utils.clip_grad_norm_(self.pi.parameters(), self.max_grad_norm) logger.add_scalar('alg/grad_norm', norm, self.t, time.time()) logger.add_scalar('alg/grad_norm_clipped', min(norm, self.max_grad_norm), self.t, time.time()) self.opt.step() for k, v in losses.items(): logger.add_scalar(f'loss/{k}', np.mean(v), self.t, time.time()) data = self.data_manager.storage.get_rollout() value_error = data['vpred'].data - data['q_mc'].data logger.add_scalar('alg/value_error_mean', value_error.mean().cpu().numpy(), self.t, time.time()) logger.add_scalar('alg/value_error_std', value_error.std().cpu().numpy(), self.t, time.time()) logger.add_scalar('alg/kl', self.compute_kl(), self.t, time.time()) return self.t
def loss(self, batch): """Compute loss.""" if self.data_manager.recurrent: outs = self.pi(batch['obs'], batch['state'], batch['mask']) else: outs = self.pi(batch['obs']) loss = {} # compute policy loss if self.t < self.policy_training_start: pi_loss = torch.Tensor([0.0]).to(self.device) else: logp = outs.dist.log_prob(batch['action']) assert logp.shape == batch['logp'].shape ratio = torch.exp(logp - batch['logp']) assert ratio.shape == batch['atarg'].shape ploss1 = ratio * batch['atarg'] ploss2 = torch.clamp(ratio, 1.0 - self.clip_param, 1.0 + self.clip_param) * batch['atarg'] pi_loss = -torch.min(ploss1, ploss2).mean() loss['pi'] = pi_loss # compute value loss vloss1 = 0.5 * self.mse(outs.value, batch['vtarg']) vpred_clipped = batch['vpred'] + (outs.value - batch['vpred']).clamp( -self.clip_param, self.clip_param) vloss2 = 0.5 * self.mse(vpred_clipped, batch['vtarg']) vf_loss = torch.max(vloss1, vloss2).mean() loss['value'] = vf_loss # compute entropy loss if self.t < self.policy_training_start: ent_loss = torch.Tensor([0.0]).to(self.device) else: ent_loss = outs.dist.entropy().mean() loss['entropy'] = ent_loss # compute residual regularizer if self.t < self.policy_training_start: reg_loss = torch.Tensor([0.0]).to(self.device) else: if self.l2_reg: reg_loss = outs.dist.rsample().pow(2).sum(dim=-1).mean() else: # huber loss ac_norm = torch.norm(outs.dist.rsample(), dim=-1) reg_loss = self.huber(ac_norm, torch.zeros_like(ac_norm)) loss['reg'] = reg_loss ############################### # Constrained loss added here. ############################### # soft plus on lambda to constrain it to be positive. lambda_ = F.softplus(self.log_lambda_) logger.add_scalar('alg/lambda', lambda_, self.t, time.time()) logger.add_scalar('alg/lambda_', self.log_lambda_, self.t, time.time()) if self.t < max(self.policy_training_start, self.lambda_training_start): loss['lambda'] = torch.Tensor([0.0]).to(self.device) else: neps = (1.0 - batch['mask']).sum() loss['lambda'] = ( lambda_ * (batch['reward'].sum() - self.reward_threshold * neps) / batch['reward'].size()[0]) if self.t >= self.policy_training_start: loss['pi'] = (reg_loss + lambda_ * loss['pi']) / (1. + lambda_) loss['total'] = (loss['pi'] + self.vf_coef * vf_loss - self.ent_coef * ent_loss) return loss
def loss(self, batch): """Loss function.""" pi_out = self.pi(batch['obs'], reparameterization_trick=True) logp = pi_out.dist.log_prob(pi_out.action) q1 = self.qf1(batch['obs'], batch['action']).value q2 = self.qf2(batch['obs'], batch['action']).value # alpha loss if self.automatic_entropy_tuning: ent_error = logp + self.target_entropy alpha_loss = -(self.log_alpha * ent_error.detach()).mean() self.opt_alpha.zero_grad() alpha_loss.backward() self.opt_alpha.step() alpha = self.log_alpha.exp() else: alpha = self.alpha alpha_loss = 0 # qf loss with torch.no_grad(): next_pi_out = self.pi(batch['next_obs']) next_ac_logp = next_pi_out.dist.log_prob(next_pi_out.action) q1_next = self.target_qf1(batch['next_obs'], next_pi_out.action).value q2_next = self.target_qf2(batch['next_obs'], next_pi_out.action).value qnext = torch.min(q1_next, q2_next) - alpha * next_ac_logp qtarg = batch['reward'] + (1.0 - batch['done']) * self.gamma * qnext assert qtarg.shape == q1.shape assert qtarg.shape == q2.shape qf1_loss = self.mse_loss(q1, qtarg) qf2_loss = self.mse_loss(q2, qtarg) # pi loss pi_loss = None if self.t % self.policy_update_period == 0: q1_pi = self.qf1(batch['obs'], pi_out.action).value q2_pi = self.qf2(batch['obs'], pi_out.action).value min_q_pi = torch.min(q1_pi, q2_pi) assert min_q_pi.shape == logp.shape pi_loss = (alpha * logp - min_q_pi).mean() # log pi loss about as frequently as other losses if self.t % self.log_period < self.policy_update_period: logger.add_scalar('loss/pi', pi_loss, self.t, time.time()) if self.t % self.log_period < self.update_period: if self.automatic_entropy_tuning: logger.add_scalar('alg/log_alpha', self.log_alpha.detach().cpu().numpy(), self.t, time.time()) scalars = { "target": self.target_entropy, "entropy": -torch.mean(logp.detach()).cpu().numpy().item() } logger.add_scalars('alg/entropy', scalars, self.t, time.time()) else: logger.add_scalar( 'alg/entropy', -torch.mean(logp.detach()).cpu().numpy().item(), self.t, time.time()) logger.add_scalar('loss/qf1', qf1_loss, self.t, time.time()) logger.add_scalar('loss/qf2', qf2_loss, self.t, time.time()) logger.add_scalar('alg/qf1', q1.mean().detach().cpu().numpy(), self.t, time.time()) logger.add_scalar('alg/qf2', q2.mean().detach().cpu().numpy(), self.t, time.time()) return pi_loss, qf1_loss, qf2_loss
def step(self): """Compute rollout, loss, and update model.""" self.pi.train() self.t += self.data_manager.rollout() losses = {'pi': [], 'vf': [], 'ent': [], 'kl': [], 'total': [], 'kl_pen': []} ####################### # Update pi ####################### if self.t >= self.policy_training_start: kl_too_big = False for _ in range(self.epochs_pi): if kl_too_big: break for batch in self.data_manager.sampler(): self.opt_pi.zero_grad() loss = self.loss_pi(batch) # break if new policy is too different from old policy if loss['kl'] > 4 * self.kl_target: kl_too_big = True break loss['total'].backward() for k, v in loss.items(): losses[k].append(v.detach().cpu().numpy()) if self.max_grad_norm: norm = nn.utils.clip_grad_norm_(self.pi.parameters(), self.max_grad_norm) logger.add_scalar('alg/grad_norm', norm, self.t, time.time()) logger.add_scalar('alg/grad_norm_clipped', min(norm, self.max_grad_norm), self.t, time.time()) self.opt_pi.step() ####################### # Update value function ####################### for _ in range(self.epochs_vf): for batch in self.data_manager.sampler(): self.opt_vf.zero_grad() loss = self.loss_vf(batch) losses['vf'].append(loss.detach().cpu().numpy()) loss.backward() if self.max_grad_norm: norm = nn.utils.clip_grad_norm_(self.vf.parameters(), self.max_grad_norm) logger.add_scalar('alg/vf_grad_norm', norm, self.t, time.time()) logger.add_scalar('alg/vf_grad_norm_clipped', min(norm, self.max_grad_norm), self.t, time.time()) self.opt_vf.step() for k, v in losses.items(): if len(v) > 0: logger.add_scalar(f'loss/{k}', np.mean(v), self.t, time.time()) # update weight on kl to match kl_target. if self.t >= self.policy_training_start: kl = self.compute_kl() if kl > 10.0 * self.kl_target and self.kl_weight < self.initial_kl_weight: self.kl_weight = self.initial_kl_weight elif kl > 1.3 * self.kl_target: self.kl_weight *= self.alpha elif kl < 0.7 * self.kl_target: self.kl_weight /= self.alpha logger.add_scalar('alg/kl', kl, self.t, time.time()) logger.add_scalar('alg/kl_weight', self.kl_weight, self.t, time.time()) data = self.data_manager.storage.get_rollout() value_error = data['vpred'].data - data['q_mc'].data logger.add_scalar('alg/value_error_mean', value_error.mean().cpu().numpy(), self.t, time.time()) logger.add_scalar('alg/value_error_std', value_error.std().cpu().numpy(), self.t, time.time()) return self.t
def loss(self, batch): """Loss function.""" dist = self.pi(batch['obs']).dist q1 = self.qf1(batch['obs'], batch['action']).value q2 = self.qf2(batch['obs'], batch['action']).value # alpha loss if self.automatic_entropy_tuning: ent_error = dist.entropy() - self.target_entropy alpha_loss = self.log_alpha * ent_error.detach().mean() self.opt_alpha.zero_grad() alpha_loss.backward() self.opt_alpha.step() alpha = self.log_alpha.exp() else: alpha = self.alpha alpha_loss = 0 # qf loss with torch.no_grad(): next_dist = self.pi(batch['next_obs']).dist q1_next = self.target_qf1(batch['next_obs']).qvals q2_next = self.target_qf2(batch['next_obs']).qvals qmin = torch.min(q1_next, q2_next) # explicitly compute the expectation over next actions qnext = torch.sum(qmin * next_dist.probs, dim=1) + alpha * next_dist.entropy() qtarg = batch['reward'] + (1.0 - batch['done']) * self.gamma * qnext assert qtarg.shape == q1.shape assert qtarg.shape == q2.shape qf1_loss = self.mse_loss(q1, qtarg) qf2_loss = self.mse_loss(q2, qtarg) # pi loss pi_loss = None if self.t % self.policy_update_period == 0: with torch.no_grad(): q1_pi = self.qf1(batch['obs']).qvals q2_pi = self.qf2(batch['obs']).qvals min_q_pi = torch.min(q1_pi, q2_pi) assert min_q_pi.shape == dist.logits.shape target_dist = CatDist(logits=min_q_pi) pi_dist = CatDist(logits=alpha * dist.logits) pi_loss = pi_dist.kl(target_dist).mean() # log pi loss about as frequently as other losses if self.t % self.log_period < self.policy_update_period: logger.add_scalar('loss/pi', pi_loss, self.t, time.time()) if self.t % self.log_period < self.update_period: if self.automatic_entropy_tuning: logger.add_scalar('alg/log_alpha', self.log_alpha.detach().cpu().numpy(), self.t, time.time()) scalars = { "target": self.target_entropy, "entropy": dist.entropy().mean().detach().cpu().numpy().item() } logger.add_scalars('alg/entropy', scalars, self.t, time.time()) else: logger.add_scalar( 'alg/entropy', dist.entropy().mean().detach().cpu().numpy().item(), self.t, time.time()) logger.add_scalar('loss/qf1', qf1_loss, self.t, time.time()) logger.add_scalar('loss/qf2', qf2_loss, self.t, time.time()) logger.add_scalar('alg/qf1', q1.mean().detach().cpu().numpy(), self.t, time.time()) logger.add_scalar('alg/qf2', q2.mean().detach().cpu().numpy(), self.t, time.time()) return pi_loss, qf1_loss, qf2_loss
def loss(self, batch): """Loss function.""" pi_out = self.pi(batch['obs'], reparameterization_trick=True) logp = pi_out.dist.log_prob(pi_out.action) q1 = self.qf1(batch['obs'], batch['action']).value q2 = self.qf2(batch['obs'], batch['action']).value # alpha loss should_update_policy = (self.t >= self.policy_training_start and self.t % self.policy_update_period == 0) if self.automatic_entropy_tuning: if should_update_policy: ent_error = logp + self.target_entropy alpha_loss = -(self.log_alpha * ent_error.detach()).mean() self.opt_alpha.zero_grad() alpha_loss.backward() self.opt_alpha.step() alpha = self.log_alpha.exp() else: alpha = self.alpha alpha_loss = 0 # qf loss with torch.no_grad(): next_pi_out = self.pi(batch['next_obs']) next_ac = next_pi_out.action # Account for the fact that we are learning about the base policy # before we start updating the residual policy if self.t < self.policy_training_start: next_ac = nest.map_structure(torch.zeros_like, next_ac) next_ac_logp = next_pi_out.dist.log_prob(next_ac) q1_next = self.target_qf1(batch['next_obs'], next_ac).value q2_next = self.target_qf2(batch['next_obs'], next_ac).value qnext = torch.min(q1_next, q2_next) - alpha * next_ac_logp qtarg = batch['reward'] + (1.0 - batch['done']) * self.gamma * qnext assert qtarg.shape == q1.shape assert qtarg.shape == q2.shape qf1_loss = self.mse_loss(q1, qtarg) + self.q_reg_weight * (q1**2).mean() qf2_loss = self.mse_loss(q2, qtarg) + self.q_reg_weight * (q2**2).mean() # pi loss pi_loss = None if should_update_policy: q1_pi = self.qf1(batch['obs'], pi_out.action).value q2_pi = self.qf2(batch['obs'], pi_out.action).value min_q_pi = torch.min(q1_pi, q2_pi) assert min_q_pi.shape == logp.shape pi_loss = (alpha * logp - min_q_pi).mean() action_reg = self.action_reg_weight * (pi_out.action**2).mean() pi_loss = pi_loss + action_reg # log pi loss about as frequently as other losses if self.t % self.log_period < self.policy_update_period: logger.add_scalar('loss/pi', pi_loss, self.t, time.time()) if self.t % self.log_period < self.update_period: if self.automatic_entropy_tuning: logger.add_scalar('alg/log_alpha', self.log_alpha.detach().cpu().numpy(), self.t, time.time()) scalars = { "target": self.target_entropy, "entropy": -torch.mean(logp.detach()).cpu().numpy().item() } logger.add_scalars('alg/entropy', scalars, self.t, time.time()) else: logger.add_scalar( 'alg/entropy', -torch.mean(logp.detach()).cpu().numpy().item(), self.t, time.time()) logger.add_scalar('loss/qf1', qf1_loss, self.t, time.time()) logger.add_scalar('loss/qf2', qf2_loss, self.t, time.time()) logger.add_scalar('alg/qf1', q1.mean().detach().cpu().numpy(), self.t, time.time()) logger.add_scalar('alg/qf2', q2.mean().detach().cpu().numpy(), self.t, time.time()) return pi_loss, qf1_loss, qf2_loss
def step(self): """Compute rollout, loss, and update model.""" self.pi.train() self.t += self.data_manager.rollout() losses = {'pi': [], 'vf': [], 'ent': [], 'kl': [], 'total': [], 'kl_pen': [], 'rnd': []} if self.norm_advantages: atarg = self.data_manager.storage.data['atarg'] atarg = atarg[:, 0] + self.rnd_coef * atarg[:, 1] self.data_manager.storage.data['atarg'][:, 0] -= atarg.mean() self.data_manager.storage.data['atarg'] /= atarg.std() + 1e-5 ####################### # Update pi ####################### kl_too_big = False for _ in range(self.epochs_pi): if kl_too_big: break for batch in self.data_manager.sampler(): self.opt_pi.zero_grad() loss = self.loss_pi(batch) # break if new policy is too different from old policy if loss['kl'] > 4 * self.kl_target: kl_too_big = True break loss['total'].backward() for k, v in loss.items(): losses[k].append(v.detach().cpu().numpy()) if self.max_grad_norm: norm = nn.utils.clip_grad_norm_(self.pi.parameters(), self.max_grad_norm) logger.add_scalar('alg/grad_norm', norm, self.t, time.time()) logger.add_scalar('alg/grad_norm_clipped', min(norm, self.max_grad_norm), self.t, time.time()) self.opt_pi.step() ####################### # Update value function ####################### for _ in range(self.epochs_vf): for batch in self.data_manager.sampler(): self.opt_vf.zero_grad() loss = self.loss_vf(batch) losses['vf'].append(loss.detach().cpu().numpy()) loss.backward() if self.max_grad_norm: norm = nn.utils.clip_grad_norm_(self.vf.parameters(), self.max_grad_norm) logger.add_scalar('alg/vf_grad_norm', norm, self.t, time.time()) logger.add_scalar('alg/vf_grad_norm_clipped', min(norm, self.max_grad_norm), self.t, time.time()) self.opt_vf.step() ####################### # Update RND ####################### for batch in self.data_manager.sampler(): self.rnd_update_count += 1 if self.rnd_update_count % self.rnd_subsample_rate == 0: loss = self.rnd.update(batch['obs']) losses['rnd'].append(loss.detach().cpu().numpy()) for k, v in losses.items(): logger.add_scalar(f'loss/{k}', np.mean(v), self.t, time.time()) # update weight on kl to match kl_target. kl = self.compute_kl() if kl > 10.0 * self.kl_target and self.kl_weight < self.initial_kl_weight: self.kl_weight = self.initial_kl_weight elif kl > 1.3 * self.kl_target: self.kl_weight *= self.alpha elif kl < 0.7 * self.kl_target: self.kl_weight /= self.alpha logger.add_scalar('alg/kl', kl, self.t, time.time()) logger.add_scalar('alg/kl_weight', self.kl_weight, self.t, time.time()) avg_return = self.data_manager.storage.data['return'][:, 0].mean(dim=0) avg_return = avg_return[0] + self.rnd_coef * avg_return[1] logger.add_scalar('alg/return', avg_return, self.t, time.time()) # log value errors errors = [] for batch in self.data_manager.sampler(): errors.append(batch['vpred'] - batch['q_mc']) errors = torch.cat(errors) logger.add_scalar('alg/value_error_mean', errors.mean().cpu().numpy(), self.t, time.time()) logger.add_scalar('alg/value_error_std', errors.std().cpu().numpy(), self.t, time.time()) return self.t
def step(self): """Compute rollout, loss, and update model.""" self.pi.train() self.t += self.data_manager.rollout() losses = { 'pi': [], 'vf': [], 'ent': [], 'kl': [], 'total': [], 'kl_pen': [], 'rnd': [], 'ide': [] } # if self.norm_advantages: # atarg = self.data_manager.storage.data['atarg'] # atarg = atarg[:, 0] + self.ngu_coef * atarg[:, 1] # self.data_manager.storage.data['atarg'][:, 0] -= atarg.mean() # self.data_manager.storage.data['atarg'] /= atarg.std() + 1e-5 if self.t >= self.policy_training_starts: ####################### # Update pi ####################### kl_too_big = False for _ in range(self.epochs_pi): if kl_too_big: break for batch in self.data_manager.sampler(): self.opt_pi.zero_grad() loss = self.loss_pi(batch) # break if new policy is too different from old policy if loss['kl'] > 4 * self.kl_target: kl_too_big = True break loss['total'].backward() for k, v in loss.items(): losses[k].append(v.detach().cpu().numpy()) if self.max_grad_norm: norm = nn.utils.clip_grad_norm_( self.pi.parameters(), self.max_grad_norm) logger.add_scalar('alg/grad_norm', norm, self.t, time.time()) logger.add_scalar('alg/grad_norm_clipped', min(norm, self.max_grad_norm), self.t, time.time()) self.opt_pi.step() ####################### # Update value function ####################### for _ in range(self.epochs_vf): for batch in self.data_manager.sampler(): self.opt_vf.zero_grad() loss = self.loss_vf(batch) losses['vf'].append(loss.detach().cpu().numpy()) loss.backward() if self.max_grad_norm: norm = nn.utils.clip_grad_norm_( self.vf.parameters(), self.max_grad_norm) logger.add_scalar('alg/vf_grad_norm', norm, self.t, time.time()) logger.add_scalar('alg/vf_grad_norm_clipped', min(norm, self.max_grad_norm), self.t, time.time()) self.opt_vf.step() rollout = self.data_manager.storage.data lens = self.data_manager.storage.sequence_lengths.int() ####################### # Store rollout_data in replay_buffer ####################### for i in range(self.nenv): if self.buffer.num_in_buffer < self.buffer_size or i % self.ngu_subsample_freq == 0: for step in range(lens[i]): def _f(x): return x[step][i].cpu().numpy() data = nest.map_structure(_f, rollout) idx = self.buffer.store_observation(data['obs']) self.buffer.store_effect( idx, { 'done': data['done'], 'action': data['action'], 'reward': data['reward'] }) ####################### # Update NGU ####################### for _ in range(self.ngu_updates): batch = self.buffer.sample(self.ngu_batch_size) def _to_torch(data): if isinstance(data, np.ndarray): return torch.from_numpy(data).to(self.device) else: return data batch = nest.map_structure(_to_torch, batch) loss = self.ngu.update_rnd(batch['obs']['ob']) losses['rnd'].append(loss.detach().cpu().numpy()) not_done = torch.logical_not(batch['done']) loss = self.ngu.update_ide(batch['obs']['ob'][not_done], batch['next_obs']['ob'][not_done], batch['action'][not_done].long()) losses['ide'].append(loss.detach().cpu().numpy()) for k, v in losses.items(): if len(v) == 0: continue logger.add_scalar(f'loss/{k}', np.mean(v), self.t, time.time()) # update weight on kl to match kl_target. 
if self.t >= self.policy_training_starts: kl = self.compute_kl() if kl > 10.0 * self.kl_target and self.kl_weight < self.initial_kl_weight: self.kl_weight = self.initial_kl_weight elif kl > 1.3 * self.kl_target: self.kl_weight *= self.alpha elif kl < 0.7 * self.kl_target: self.kl_weight /= self.alpha else: kl = 0.0 logger.add_scalar('alg/kl', kl, self.t, time.time()) logger.add_scalar('alg/kl_weight', self.kl_weight, self.t, time.time()) avg_return = self.data_manager.storage.data['return'][:, 0].mean(dim=0) avg_return = (avg_return[0] + self.ngu_coef * avg_return[1]) / (1 + self.ngu_coef) logger.add_scalar('alg/return', avg_return, self.t, time.time()) # log value errors errors = [] for batch in self.data_manager.sampler(): errors.append(batch['vpred'] - batch['q_mc']) errors = torch.cat(errors) logger.add_scalar('alg/value_error_mean', errors.mean().cpu().numpy(), self.t, time.time()) logger.add_scalar('alg/value_error_std', errors.std().cpu().numpy(), self.t, time.time()) return self.t
def loss(self, batch): """Loss function.""" pi_out = self.pi(batch['obs'], reparameterization_trick=self.rsample) if self.discrete: new_ac = pi_out.action ent = pi_out.dist.entropy() else: assert isinstance(pi_out.dist, TanhNormal), ( "It is strongly encouraged that you use a TanhNormal " "action distribution for continuous action spaces.") if self.rsample: new_ac, new_pth_ac = pi_out.dist.rsample( return_pretanh_value=True) else: new_ac, new_pth_ac = pi_out.dist.sample( return_pretanh_value=True) logp = pi_out.dist.log_prob(new_ac, new_pth_ac) q1 = self.qf1(batch['obs'], batch['action']).value q2 = self.qf2(batch['obs'], batch['action']).value v = self.vf(batch['obs']).value # alpha loss if self.automatic_entropy_tuning: if self.discrete: ent_error = -ent + self.target_entropy else: ent_error = logp + self.target_entropy alpha_loss = -(self.log_alpha * ent_error.detach()).mean() self.opt_alpha.zero_grad() alpha_loss.backward() self.opt_alpha.step() alpha = self.log_alpha.exp() else: alpha = 1 alpha_loss = 0 # qf loss vtarg = self.target_vf(batch['next_obs']).value qtarg = self.reward_scale * batch['reward'].float() + ( (1.0 - batch['done']) * self.gamma * vtarg) assert qtarg.shape == q1.shape assert qtarg.shape == q2.shape qf1_loss = self.qf_criterion(q1, qtarg.detach()) qf2_loss = self.qf_criterion(q2, qtarg.detach()) # vf loss q1_outs = self.qf1(batch['obs'], new_ac) q1_new = q1_outs.value q2_new = self.qf2(batch['obs'], new_ac).value q = torch.min(q1_new, q2_new) if self.discrete: vtarg = q + alpha * ent else: vtarg = q - alpha * logp assert v.shape == vtarg.shape vf_loss = self.vf_criterion(v, vtarg.detach()) # pi loss pi_loss = None if self.t % self.policy_update_period == 0: if self.discrete: target_dist = CatDist(logits=q1_outs.qvals.detach()) pi_dist = CatDist(logits=alpha * pi_out.dist.logits) pi_loss = pi_dist.kl(target_dist).mean() else: if self.rsample: assert q.shape == logp.shape pi_loss = (alpha*logp - q1_new).mean() else: pi_targ = q1_new - v assert pi_targ.shape == logp.shape pi_loss = (logp * (alpha * logp - pi_targ).detach()).mean() pi_loss += self.policy_mean_reg_weight * ( pi_out.dist.normal.mean**2).mean() # log pi loss about as frequently as other losses if self.t % self.log_period < self.policy_update_period: logger.add_scalar('loss/pi', pi_loss, self.t, time.time()) if self.t % self.log_period < self.update_period: if self.automatic_entropy_tuning: logger.add_scalar('ent/log_alpha', self.log_alpha.detach().cpu().numpy(), self.t, time.time()) if self.discrete: scalars = {"target": self.target_entropy, "entropy": ent.mean().detach().cpu().numpy().item()} else: scalars = {"target": self.target_entropy, "entropy": -torch.mean( logp.detach()).cpu().numpy().item()} logger.add_scalars('ent/entropy', scalars, self.t, time.time()) else: if self.discrete: logger.add_scalar( 'ent/entropy', ent.mean().detach().cpu().numpy().item(), self.t, time.time()) else: logger.add_scalar( 'ent/entropy', -torch.mean(logp.detach()).cpu().numpy().item(), self.t, time.time()) logger.add_scalar('loss/qf1', qf1_loss, self.t, time.time()) logger.add_scalar('loss/qf2', qf2_loss, self.t, time.time()) logger.add_scalar('loss/vf', vf_loss, self.t, time.time()) return pi_loss, qf1_loss, qf2_loss, vf_loss