def _update_actor(self, batch_tensors):
    info = {}
    cur_obs, actions, advantages = dutil.get_keys(
        batch_tensors,
        SampleBatch.CUR_OBS,
        SampleBatch.ACTIONS,
        Postprocessing.ADVANTAGES,
    )
    # Normalize advantages to zero mean and unit variance for numerical stability
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    # Compute whitening matrices
    n_samples = self.config["logp_samples"]
    with self.optimizers["actor"].record_stats():
        _, log_prob = self.module.actor.sample(cur_obs, (n_samples,))
        log_prob.mean().backward()

    # Compute surrogate loss
    with self.optimizers.optimize("actor"):
        surr_loss = -(self.module.actor.log_prob(cur_obs, actions) * advantages).mean()
        info["loss(actor)"] = surr_loss.item()
        surr_loss.backward()

    pol_grad = [p.grad.clone() for p in self.module.actor.parameters()]
    if self.config["line_search"]:
        info.update(self._perform_line_search(pol_grad, surr_loss, batch_tensors))

    return info

def _update_critic(self, batch_tensors):
    cur_obs, value_targets = dutil.get_keys(
        batch_tensors,
        SampleBatch.CUR_OBS,
        Postprocessing.VALUE_TARGETS,
    )
    mse = nn.MSELoss()
    fake_dist = Normal()
    fake_scale = torch.ones_like(value_targets)
    for _ in range(self.config["vf_iters"]):
        if isinstance(self.optimizers["critic"], KFACMixin):
            # Compute whitening matrices by treating the critic's outputs as
            # the mean of a unit-variance Gaussian and differentiating the
            # log-likelihood of fake samples drawn around them
            with self.optimizers["critic"].record_stats():
                values = self.module.critic(cur_obs).squeeze(-1)
                fake_samples = values + torch.randn_like(values)
                log_prob = fake_dist.log_prob(
                    fake_samples.detach(), {"loc": values, "scale": fake_scale}
                )
                log_prob.mean().backward()

        with self.optimizers.optimize("critic"):
            mse_loss = mse(self.module.critic(cur_obs).squeeze(-1), value_targets)
            mse_loss.backward()

    return {"loss(critic)": mse_loss.item()}

def __call__(self, batch):
    """Compute loss for Q-value function."""
    # pylint:disable=too-many-arguments
    obs, actions, rewards, next_obs, dones = dutil.get_keys(
        batch,
        SampleBatch.CUR_OBS,
        SampleBatch.ACTIONS,
        SampleBatch.REWARDS,
        SampleBatch.NEXT_OBS,
        SampleBatch.DONES,
    )
    with torch.no_grad():
        target_values = self.critic_targets(rewards, next_obs, dones)
    loss_fn = nn.MSELoss()
    values = torch.cat([m(obs, actions) for m in self.critics], dim=-1)
    critic_loss = loss_fn(values, target_values.unsqueeze(-1).expand_as(values))

    stats = {
        "q_mean": values.mean().item(),
        "q_max": values.max().item(),
        "q_min": values.min().item(),
        "loss(critic)": critic_loss.item(),
    }
    return critic_loss, stats

def extra_grad_info(self, batch_tensors):  # pylint:disable=unused-argument
    """Return statistics right after components are updated."""
    cur_obs, actions, old_logp, value_targets, value_preds = dutil.get_keys(
        batch_tensors,
        SampleBatch.CUR_OBS,
        SampleBatch.ACTIONS,
        SampleBatch.ACTION_LOGP,
        Postprocessing.VALUE_TARGETS,
        SampleBatch.VF_PREDS,
    )
    info = {
        "kl_divergence": torch.mean(
            old_logp - self.module.actor.log_prob(cur_obs, actions)
        ).item(),
        "entropy": torch.mean(-old_logp).item(),
        "perplexity": torch.mean(-old_logp).exp().item(),
        "explained_variance": explained_variance(
            value_targets.numpy(), value_preds.numpy()
        ),
    }
    info.update(
        {
            f"grad_norm({k})": nn.utils.clip_grad_norm_(
                self.module[k].parameters(), float("inf")
            ).item()
            for k in ("actor", "critic")
        }
    )
    return info

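# `explained_variance` above is presumably the usual regression diagnostic,
# 1 - Var[targets - preds] / Var[targets]; below is a minimal numpy sketch
# under that assumption, not necessarily the imported implementation.
import numpy as np

def explained_variance_sketch(targets: np.ndarray, preds: np.ndarray) -> float:
    """1.0 means perfect value predictions; 0.0 means no better than a constant."""
    target_var = np.var(targets)
    if target_var == 0:
        return np.nan
    return 1.0 - np.var(targets - preds) / target_var
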
def plot_action_distributions(outputs, bins, ranges=()):
    acts, det = map(lambda x: x.numpy(), dutil.get_keys(outputs, "acts", "det"))
    data = {f"act[{i}]": acts[..., i] for i in range(acts.shape[-1])}
    dataset = pd.DataFrame(data)
    det_data = {f"det[{i}]": det[..., i] for i in range(det.shape[-1])}
    det_dataset = pd.DataFrame(det_data)
    st.bokeh_chart(make_histograms(dataset, bins, ranges=ranges))
    st.bokeh_chart(scatter_matrix(dataset, det_dataset))

def sampled_one_step_state_values(self, batch):
    """Bootstrapped approximation of true state-value using sampled transition."""
    next_obs, rewards, dones = dutil.get_keys(
        batch,
        SampleBatch.NEXT_OBS,
        SampleBatch.REWARDS,
        SampleBatch.DONES,
    )
    return torch.where(
        dones,
        rewards,
        rewards + self.config["gamma"] * self.target_critic(next_obs).squeeze(-1),
    )

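# A self-contained sketch of the bootstrapped one-step target computed above,
# with precomputed next-state values standing in for `self.target_critic`;
# `gamma` and the toy tensors are illustrative assumptions only.
import torch

def one_step_targets(rewards, next_values, dones, gamma=0.99):
    # Terminal transitions use the bare reward; the rest bootstrap on the
    # discounted next-state value, mirroring the torch.where above.
    return torch.where(dones, rewards, rewards + gamma * next_values)

rewards = torch.tensor([1.0, 0.5])
next_values = torch.tensor([2.0, 3.0])
dones = torch.tensor([True, False])
print(one_step_targets(rewards, next_values, dones))  # tensor([1.0000, 3.4700])
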
def __call__(self, batch):
    """Compute loss for importance sampled fitted V iteration."""
    obs, is_ratios = dutil.get_keys(batch, SampleBatch.CUR_OBS, self.IS_RATIOS)
    values = self.critic(obs).squeeze(-1)
    with torch.no_grad():
        targets = self.sampled_one_step_state_values(batch)
    value_loss = torch.mean(
        is_ratios * torch.nn.MSELoss(reduction="none")(values, targets) / 2
    )
    return value_loss, {"loss(critic)": value_loss.item()}

def __call__(self, batch: Dict[str, Tensor]) -> Tuple[Tensor, Dict[str, float]]:
    """Compute Maximum Likelihood Estimation (MLE) model loss.

    Returns:
        A tuple containing a 0d loss tensor and a dictionary of loss
        statistics
    """
    obs, actions, next_obs = get_keys(batch, *self.batch_keys)
    loss = -self.model_likelihood(obs, actions, next_obs).mean()
    return loss, {"loss(model)": loss.item()}

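# A toy illustration of the MLE objective above under a Gaussian dynamics-model
# assumption: the loss is the negative mean log-likelihood of observed next
# states. The `loc`/`scale` inputs stand in for a model's outputs and are not
# raylab's model API.
import torch
from torch.distributions import Normal

def gaussian_mle_loss(loc, scale, next_obs):
    # Sum log-densities over state dimensions, then average over the batch
    return -Normal(loc, scale).log_prob(next_obs).sum(-1).mean()

loc, scale = torch.zeros(4, 3), torch.ones(4, 3)
print(gaussian_mle_loss(loc, scale, torch.randn(4, 3)))
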
def __call__(self, batch: TensorDict) -> Tuple[Tensor, TensorDict]:
    """Compute loss for Q-value function."""
    obs, actions, rewards, next_obs, dones = dutil.get_keys(batch, *self.batch_keys)

    with torch.no_grad():
        target_values = self.critic_targets(rewards, next_obs, dones)
    loss_fn = nn.MSELoss()
    values = self.critics(obs, actions)
    critic_loss = torch.stack([loss_fn(v, target_values) for v in values]).sum()

    stats = {"loss(critics)": critic_loss.item()}
    stats.update(self.q_value_info(values))
    return critic_loss, stats

def unpack_batch(self, batch: TensorDict) -> Tuple[Tensor, ...]:
    """Return the batch tensors corresponding to the batch keys.

    Tensors are returned in the order in which `batch_keys` is defined.

    Args:
        batch: Dictionary of input tensors

    Returns:
        A tuple of tensors corresponding to each key in `batch_keys`
    """
    return tuple(get_keys(batch, *self.batch_keys))

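# A hypothetical stand-in for `get_keys`, mirroring the behavior assumed of it
# in `unpack_batch`: values come back in the order the keys are requested.
def get_keys_sketch(mapping, *keys):
    return (mapping[k] for k in keys)

batch = {"obs": [1], "act": [2], "rew": [3]}
obs, act = tuple(get_keys_sketch(batch, "obs", "act"))
print(obs, act)  # [1] [2]
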
def __call__(self, batch: Dict[str, Tensor]) -> Tuple[Tensor, Dict[str, float]]:
    """Compute Maximum Likelihood Estimation (MLE) loss for each model.

    Returns:
        A tuple with a 1d loss tensor containing each model's loss and a
        dictionary of loss statistics
    """
    obs, actions, next_obs = get_keys(batch, *self.batch_keys)
    logps = self.model_likelihoods(obs, actions, next_obs)
    loss = -torch.stack(logps)
    info = {f"loss(models[{i}])": -l.item() for i, l in enumerate(logps)}
    return loss, info

def __call__(self, batch: TensorDict) -> Tuple[Tensor, StatDict]:
    """Compute Maximum Likelihood Estimation (MLE) model loss.

    Returns:
        A tuple with a 0d tensor averaging each model's loss and a dictionary
        of per-model loss statistics. The 1d tensor of per-model losses is
        kept in `self._last_output`.
    """
    obs, act, new_obs = get_keys(batch, *self.batch_keys)
    nlls = self.loss_fns(obs, act, new_obs)
    losses = torch.stack(nlls)
    info = {f"{self.tag}(models[{i}])": n.item() for i, n in enumerate(nlls)}
    self._last_output = (losses, info)
    return losses.mean(), info

def test_compute_value_targets(policy_and_batch):
    policy, batch = policy_and_batch
    rewards, dones = get_keys(batch, SampleBatch.REWARDS, SampleBatch.DONES)

    targets = policy.loss_critic.sampled_one_step_state_values(batch)
    assert targets.shape == (10,)
    assert targets.dtype == torch.float32
    assert torch.allclose(targets[dones], rewards[dones])

    policy.module.zero_grad()
    targets.mean().backward()
    target_params = set(policy.module.target_critic.parameters())
    other_params = (p for p in policy.module.parameters() if p not in target_params)
    assert all(p.grad is not None for p in target_params)
    assert all(p.grad is None for p in other_params)

def test_target_value(cdq_loss, batch, critics, target_critic):
    modules = nn.ModuleList([critics, target_critic])
    rewards, next_obs, dones = dutil.get_keys(
        batch, SampleBatch.REWARDS, SampleBatch.NEXT_OBS, SampleBatch.DONES
    )

    targets = cdq_loss.critic_targets(rewards, next_obs, dones)
    assert torch.is_tensor(targets)
    assert targets.shape == (len(next_obs),)
    assert targets.dtype == torch.float32
    assert torch.allclose(targets[dones], rewards[dones])

    modules.zero_grad()
    targets.mean().backward()
    target_params = set(target_critic.parameters())
    assert all(p.grad is not None for p in target_params)
    assert all(p.grad is None for p in set(critics.parameters()))

def _perform_line_search(self, pol_grad, surr_loss, batch_tensors):
    # pylint:disable=too-many-locals
    kl_clip = self.optimizers["actor"].state["kl_clip"]
    expected_improvement = sum(
        (g * p.grad.data).sum()
        for g, p in zip(pol_grad, self.module.actor.parameters())
    ).item()
    cur_obs, actions, old_logp, advantages = dutil.get_keys(
        batch_tensors,
        SampleBatch.CUR_OBS,
        SampleBatch.ACTIONS,
        SampleBatch.ACTION_LOGP,
        Postprocessing.ADVANTAGES,
    )

    @torch.no_grad()
    def f_barrier(scale):
        for par in self.module.actor.parameters():
            par.data.add_(par.grad.data, alpha=scale)
        new_logp = self.module.actor.log_prob(cur_obs, actions)
        for par in self.module.actor.parameters():
            par.data.sub_(par.grad.data, alpha=scale)
        surr_loss = self._compute_surr_loss(old_logp, new_logp, advantages)
        avg_kl = torch.mean(old_logp - new_logp)
        return surr_loss.item() if avg_kl < kl_clip else np.inf

    scale, expected_improvement, improvement = line_search(
        f_barrier,
        1,
        1,
        expected_improvement,
        y_0=surr_loss.item(),
        **self.config["line_search_options"],
    )
    improvement_ratio = (
        improvement / expected_improvement if expected_improvement else np.nan
    )
    info = {
        "expected_improvement": expected_improvement,
        "actual_improvement": improvement,
        "improvement_ratio": improvement_ratio,
    }
    for par in self.module.actor.parameters():
        par.data.add_(par.grad.data, alpha=scale)
    return info

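# A minimal backtracking line-search sketch consistent with how `line_search`
# is called above: try progressively scaled steps along `dx` from `x_0` until
# the realized improvement over `y_0` is an acceptable fraction of the scaled
# expected improvement. The option names (`accept_ratio`, `backtrack_ratio`,
# `max_backtracks`) are assumptions about `line_search_options`, not the
# actual implementation.
import numpy as np

def backtracking_line_search(func, x_0, dx, expected_improvement, y_0,
                             accept_ratio=0.1, backtrack_ratio=0.8,
                             max_backtracks=15):
    for exp in range(max_backtracks):
        ratio = backtrack_ratio ** exp
        x_new = x_0 - ratio * dx
        improvement = y_0 - func(x_new)
        # Accept once the improvement-to-expectation ratio clears the threshold
        if improvement / (expected_improvement * ratio) >= accept_ratio:
            return x_new, expected_improvement * ratio, improvement
    # No acceptable step found: return the starting point with zero improvement
    return x_0, expected_improvement, 0.0
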
def __call__(self, batch: Dict[str, Tensor]) -> Tuple[Tensor, Dict[str, float]]:
    """Compute bootstrapped Stochastic Value Gradient loss."""
    assert (
        self._reward_fn is not None
    ), "No reward function set. Did you call `set_reward_fn`?"
    obs, actions, next_obs, dones, is_ratios = get_keys(
        batch,
        SampleBatch.CUR_OBS,
        SampleBatch.ACTIONS,
        SampleBatch.NEXT_OBS,
        SampleBatch.DONES,
        self.IS_RATIOS,
    )

    state_val = self.one_step_reproduced_state_value(obs, actions, next_obs, dones)
    svg_loss = -torch.mean(is_ratios * state_val)
    return svg_loss, {"loss(actor)": svg_loss.item()}

def test_target_value(policy_and_batch):
    policy, batch = policy_and_batch
    loss_fn = loss_maker(policy)
    rewards, next_obs, dones = dutil.get_keys(
        batch, SampleBatch.REWARDS, SampleBatch.NEXT_OBS, SampleBatch.DONES
    )

    targets = loss_fn.critic_targets(rewards, next_obs, dones)
    assert targets.shape == (len(next_obs),)
    assert targets.dtype == torch.float32
    assert torch.allclose(targets[dones], rewards[dones])

    policy.module.zero_grad()
    targets.mean().backward()
    target_params = set(policy.module.target_critics.parameters())
    target_params.update(set(policy.module.actor.parameters()))
    assert all(p.grad is not None for p in target_params)
    assert all(
        p.grad is None for p in set(policy.module.parameters()) - target_params
    )

def sampled_one_step_state_values(self, batch):
    """Bootstrapped approximation of true state-value using sampled transition."""
    if self.ENTROPY in batch:
        entropy = batch[self.ENTROPY]
    else:
        with torch.no_grad():
            _, logp = self.actor(batch[SampleBatch.CUR_OBS])
            entropy = -logp

    next_obs, rewards, dones = dutil.get_keys(
        batch,
        SampleBatch.NEXT_OBS,
        SampleBatch.REWARDS,
        SampleBatch.DONES,
    )
    gamma = self.config["gamma"]
    augmented_rewards = rewards + self.alpha() * entropy
    return torch.where(
        dones,
        augmented_rewards,
        augmented_rewards + gamma * self.target_critic(next_obs).squeeze(-1),
    )

def _update_critic(self, batch_tensors):
    info = {}
    mse = nn.MSELoss()
    cur_obs, value_targets, value_preds = get_keys(
        batch_tensors,
        SampleBatch.CUR_OBS,
        Postprocessing.VALUE_TARGETS,
        SampleBatch.VF_PREDS,
    )

    for _ in range(self.config["val_iters"]):
        with self.optimizers.optimize("critic"):
            loss = mse(self.module.critic(cur_obs).squeeze(-1), value_targets)
            loss.backward()

    info["vf_loss"] = loss.item()
    info["explained_variance"] = explained_variance(
        value_targets.numpy(), value_preds.numpy()
    )
    return info

def test_truncated_svg(policy_and_batch):
    policy, batch = policy_and_batch
    obs, actions, next_obs, rewards, dones = get_keys(
        batch,
        SampleBatch.CUR_OBS,
        SampleBatch.ACTIONS,
        SampleBatch.NEXT_OBS,
        SampleBatch.REWARDS,
        SampleBatch.DONES,
    )

    state_vals = policy.loss_actor.one_step_reproduced_state_value(
        obs, actions, next_obs, dones
    )
    assert state_vals.shape == (10,)
    assert state_vals.dtype == torch.float32
    assert torch.allclose(state_vals[dones], rewards[dones])

    state_vals.mean().backward()
    assert all(p.grad is not None for p in policy.module.actor.parameters())

def _update_actor(self, batch_tensors):
    info = {}
    cur_obs, actions, advantages = get_keys(
        batch_tensors,
        SampleBatch.CUR_OBS,
        SampleBatch.ACTIONS,
        Postprocessing.ADVANTAGES,
    )
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    # Compute Policy Gradient
    surr_loss = -(self.module.actor.log_prob(cur_obs, actions) * advantages).mean()
    pol_grad = flat_grad(surr_loss, self.module.actor.parameters())
    info["grad_norm(pg)"] = pol_grad.norm().item()

    # Compute Natural Gradient
    descent_step, cg_info = self._compute_descent_step(pol_grad, cur_obs)
    info["grad_norm(nat)"] = descent_step.norm().item()
    info.update(cg_info)

    # Perform Line Search
    if self.config["line_search"]:
        new_params, line_search_info = self._perform_line_search(
            pol_grad,
            descent_step,
            surr_loss,
            batch_tensors,
        )
        info.update(line_search_info)
    else:
        new_params = (
            parameters_to_vector(self.module.actor.parameters()) - descent_step
        )
    vector_to_parameters(new_params, self.module.actor.parameters())
    return info

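# `_compute_descent_step` presumably solves F x = g for the natural-gradient
# direction via conjugate gradients with Fisher-vector products. Below is a
# standard, self-contained CG sketch under that assumption; `fvp_fn` is a
# hypothetical callable returning the product of the Fisher matrix with a
# vector.
import torch

def conjugate_gradient(fvp_fn, b, n_iters=10, residual_tol=1e-10):
    # Iteratively solve F x = b using only matrix-vector products F p
    x = torch.zeros_like(b)
    r = b.clone()
    p = b.clone()
    rdotr = r.dot(r)
    for _ in range(n_iters):
        f_p = fvp_fn(p)
        alpha = rdotr / p.dot(f_p)
        x = x + alpha * p
        r = r - alpha * f_p
        new_rdotr = r.dot(r)
        if new_rdotr < residual_tol:
            break
        p = r + (new_rdotr / rdotr) * p
        rdotr = new_rdotr
    return x
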
def _perform_line_search(self, pol_grad, descent_step, surr_loss, batch_tensors):
    expected_improvement = pol_grad.dot(descent_step).item()
    cur_obs, actions, old_logp, advantages = get_keys(
        batch_tensors,
        SampleBatch.CUR_OBS,
        SampleBatch.ACTIONS,
        SampleBatch.ACTION_LOGP,
        Postprocessing.ADVANTAGES,
    )

    @torch.no_grad()
    def f_barrier(params):
        vector_to_parameters(params, self.module.actor.parameters())
        new_logp = self.module.actor.log_prob(cur_obs, actions)
        surr_loss = self._compute_surr_loss(old_logp, new_logp, advantages)
        avg_kl = torch.mean(old_logp - new_logp)
        return surr_loss.item() if avg_kl < self.config["delta"] else np.inf

    new_params, expected_improvement, improvement = line_search(
        f_barrier,
        parameters_to_vector(self.module.actor.parameters()),
        descent_step,
        expected_improvement,
        y_0=surr_loss.item(),
        **self.config["line_search_options"],
    )
    improvement_ratio = (
        improvement / expected_improvement if expected_improvement else np.nan
    )
    info = {
        "expected_improvement": expected_improvement,
        "actual_improvement": improvement,
        "improvement_ratio": improvement_ratio,
    }
    return new_params, info

def test_critic_loss(policy_and_batch):
    policy, batch = policy_and_batch
    loss_fn = loss_maker(policy)

    loss, info = loss_fn(batch)
    assert loss.shape == ()
    assert loss.dtype == torch.float32
    assert isinstance(info, dict)

    params = set(policy.module.critics.parameters())
    loss.backward()
    assert all(p.grad is not None for p in params)
    assert all(p.grad is None for p in set(policy.module.parameters()) - params)

    obs, acts = dutil.get_keys(batch, SampleBatch.CUR_OBS, SampleBatch.ACTIONS)
    vals = [m(obs, acts) for m in policy.module.critics]
    concat_vals = torch.cat(vals, dim=-1)
    targets = torch.randn_like(vals[0])
    loss_fn = nn.MSELoss()
    # MSE over the concatenated values equals the average of per-critic MSEs
    assert torch.allclose(
        loss_fn(concat_vals, targets.expand_as(concat_vals)),
        sum(loss_fn(val, targets) for val in vals) / len(vals),
    )