Example #1
    def _update_actor(self, batch_tensors):
        info = {}
        cur_obs, actions, advantages = dutil.get_keys(
            batch_tensors,
            SampleBatch.CUR_OBS,
            SampleBatch.ACTIONS,
            Postprocessing.ADVANTAGES,
        )
        advantages = (advantages - advantages.mean()) / (advantages.std() +
                                                         1e-8)

        # Compute whitening matrices
        n_samples = self.config["logp_samples"]
        with self.optimizers["actor"].record_stats():
            _, log_prob = self.module.actor.sample(cur_obs, (n_samples, ))
            log_prob.mean().backward()

        # Compute surrogate loss
        with self.optimizers.optimize("actor"):
            surr_loss = -(self.module.actor.log_prob(cur_obs, actions) *
                          advantages).mean()
            info["loss(actor)"] = surr_loss.item()
            surr_loss.backward()
            pol_grad = [p.grad.clone() for p in self.module.actor.parameters()]

        if self.config["line_search"]:
            info.update(
                self._perform_line_search(pol_grad, surr_loss, batch_tensors))

        return info
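
Every example on this page unpacks batch tensors through `get_keys` (sometimes imported as `dutil.get_keys`). The helper itself is not shown here; the sketch below is only an assumption inferred from the call sites, which pass a mapping plus any number of keys and either unpack the result directly or wrap it in `tuple()`:

# Hedged sketch: a minimal stand-in for get_keys, inferred from how the call
# sites on this page use it; the library's actual helper may differ.
from typing import Any, Iterator, Mapping


def get_keys(mapping: Mapping[str, Any], *keys: str) -> Iterator[Any]:
    """Yield the value stored under each key, in the order the keys are given."""
    return (mapping[key] for key in keys)


# Usage mirroring the examples on this page:
#     obs, actions = get_keys(batch, SampleBatch.CUR_OBS, SampleBatch.ACTIONS)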
Example #2
    def _update_critic(self, batch_tensors):
        cur_obs, value_targets = dutil.get_keys(
            batch_tensors,
            SampleBatch.CUR_OBS,
            Postprocessing.VALUE_TARGETS,
        )
        mse = nn.MSELoss()
        fake_dist = Normal()
        fake_scale = torch.ones_like(value_targets)

        for _ in range(self.config["vf_iters"]):
            if isinstance(self.optimizers["critic"], KFACMixin):
                # Compute whitening matrices
                with self.optimizers["critic"].record_stats():
                    values = self.module.critic(cur_obs).squeeze(-1)
                    fake_samples = values + torch.randn_like(values)
                    log_prob = fake_dist.log_prob(fake_samples.detach(), {
                        "loc": values,
                        "scale": fake_scale
                    })
                    log_prob.mean().backward()

            with self.optimizers.optimize("critic"):
                mse_loss = mse(
                    self.module.critic(cur_obs).squeeze(-1), value_targets)
                mse_loss.backward()

        return {"loss(critic)": mse_loss.item()}
Example #3
    def __call__(self, batch):
        """Compute loss for Q-value function."""
        # pylint:disable=too-many-arguments
        obs, actions, rewards, next_obs, dones = dutil.get_keys(
            batch,
            SampleBatch.CUR_OBS,
            SampleBatch.ACTIONS,
            SampleBatch.REWARDS,
            SampleBatch.NEXT_OBS,
            SampleBatch.DONES,
        )
        with torch.no_grad():
            target_values = self.critic_targets(rewards, next_obs, dones)
        loss_fn = nn.MSELoss()
        values = torch.cat([m(obs, actions) for m in self.critics], dim=-1)
        critic_loss = loss_fn(values,
                              target_values.unsqueeze(-1).expand_as(values))

        stats = {
            "q_mean": values.mean().item(),
            "q_max": values.max().item(),
            "q_min": values.min().item(),
            "loss(critic)": critic_loss.item(),
        }
        return critic_loss, stats
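
`critic_targets` is called under `torch.no_grad()` but is not defined in this example; the tests in Examples #14 and #17 only check its output shape, dtype, and that `targets[dones]` equals `rewards[dones]`. The following is a hedged sketch of one common form of that target (clipped double Q-learning with a target actor); the actual helper, and in particular which actor picks the next action, may differ between the algorithms shown on this page:

# Hedged sketch of a clipped double-Q bootstrap target consistent with the
# assertions in the tests further down (shape (batch,), targets equal rewards
# at terminal states). target_actor, target_critics and gamma are illustrative
# parameters, not the library's actual attributes.
import torch


def critic_targets(rewards, next_obs, dones, target_actor, target_critics, gamma):
    """Compute r + gamma * min_i Q_i(s', pi(s')) on non-terminal transitions."""
    next_acts = target_actor(next_obs)
    next_vals = torch.cat([m(next_obs, next_acts) for m in target_critics], dim=-1)
    next_vals, _ = next_vals.min(dim=-1)
    return torch.where(dones, rewards, rewards + gamma * next_vals)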
Example #4
    def extra_grad_info(self, batch_tensors):  # pylint:disable=unused-argument
        """Return statistics right after components are updated."""
        cur_obs, actions, old_logp, value_targets, value_preds = dutil.get_keys(
            batch_tensors,
            SampleBatch.CUR_OBS,
            SampleBatch.ACTIONS,
            SampleBatch.ACTION_LOGP,
            Postprocessing.VALUE_TARGETS,
            SampleBatch.VF_PREDS,
        )

        info = {
            "kl_divergence":
            torch.mean(old_logp -
                       self.module.actor.log_prob(cur_obs, actions)).item(),
            "entropy":
            torch.mean(-old_logp).item(),
            "perplexity":
            torch.mean(-old_logp).exp().item(),
            "explained_variance":
            explained_variance(value_targets.numpy(), value_preds.numpy()),
        }
        info.update({
            f"grad_norm({k})":
            nn.utils.clip_grad_norm_(self.module[k].parameters(),
                                     float("inf")).item()
            for k in ("actor", "critic")
        })
        return info
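
`explained_variance` summarizes how much of the variance in the value targets the critic's predictions account for. The library's helper is not shown here; below is a sketch using the conventional definition 1 - Var[y - y_pred] / Var[y]:

# Hedged sketch of the explained_variance statistic reported above, using the
# conventional definition; the actual helper may handle the zero-variance
# edge case differently.
import numpy as np


def explained_variance(targets: np.ndarray, preds: np.ndarray) -> float:
    """Fraction of target variance accounted for by the predictions (1.0 is perfect)."""
    target_var = np.var(targets)
    if target_var == 0:
        return np.nan
    return float(1.0 - np.var(targets - preds) / target_var)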
Example #5
def plot_action_distributions(outputs, bins, ranges=()):
    acts, det = map(lambda x: x.numpy(), dutil.get_keys(outputs, "acts", "det"))
    data = {f"act[{i}]": acts[..., i] for i in range(acts.shape[-1])}
    dataset = pd.DataFrame(data)

    det_data = {f"det[{i}]": det[..., i] for i in range(det.shape[-1])}
    det_dataset = pd.DataFrame(det_data)

    st.bokeh_chart(make_histograms(dataset, bins, ranges=ranges))
    st.bokeh_chart(scatter_matrix(dataset, det_dataset))
Example #6
    def sampled_one_step_state_values(self, batch):
        """Bootstrapped approximation of true state-value using sampled transition."""
        next_obs, rewards, dones = dutil.get_keys(
            batch, SampleBatch.NEXT_OBS, SampleBatch.REWARDS, SampleBatch.DONES,
        )
        return torch.where(
            dones,
            rewards,
            rewards + self.config["gamma"] * self.target_critic(next_obs).squeeze(-1),
        )
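
The `torch.where` call above masks the bootstrap term so that terminal transitions contribute only their immediate reward. A small self-contained check of that masking, with dummy values chosen purely for illustration:

# Dummy tensors illustrating the masking behaviour of torch.where as used
# above; the numbers are arbitrary.
import torch

rewards = torch.tensor([1.0, 2.0, 3.0])
next_values = torch.tensor([10.0, 10.0, 10.0])
dones = torch.tensor([False, True, False])
gamma = 0.99

targets = torch.where(dones, rewards, rewards + gamma * next_values)
print(targets)  # tensor([10.9000,  2.0000, 12.9000])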
Example #7
    def __call__(self, batch):
        """Compute loss for importance sampled fitted V iteration."""
        obs, is_ratios = dutil.get_keys(batch, SampleBatch.CUR_OBS, self.IS_RATIOS)

        values = self.critic(obs).squeeze(-1)
        with torch.no_grad():
            targets = self.sampled_one_step_state_values(batch)
        value_loss = torch.mean(
            is_ratios * torch.nn.MSELoss(reduction="none")(values, targets) / 2
        )
        return value_loss, {"loss(critic)": value_loss.item()}
Example #8
    def __call__(self, batch: Dict[str,
                                   Tensor]) -> Tuple[Tensor, Dict[str, float]]:
        """Compute Maximum Likelihood Estimation (MLE) model loss.

        Returns:
            A tuple containing a 0d loss tensor and a dictionary of loss
            statistics
        """
        obs, actions, next_obs = get_keys(batch, *self.batch_keys)
        loss = -self.model_likelihood(obs, actions, next_obs).mean()
        return loss, {"loss(model)": loss.item()}
Example #9
    def __call__(self, batch: TensorDict) -> Tuple[Tensor, TensorDict]:
        """Compute loss for Q-value function."""
        obs, actions, rewards, next_obs, dones = dutil.get_keys(batch, *self.batch_keys)
        with torch.no_grad():
            target_values = self.critic_targets(rewards, next_obs, dones)
        loss_fn = nn.MSELoss()
        values = self.critics(obs, actions)
        critic_loss = torch.stack([loss_fn(v, target_values) for v in values]).sum()

        stats = {"loss(critics)": critic_loss.item()}
        stats.update(self.q_value_info(values))
        return critic_loss, stats
Example #10
    def unpack_batch(self, batch: TensorDict) -> Tuple[Tensor, ...]:
        """Returns the batch tensors corresponding to the batch keys.

        Tensors are returned in the same order `batch_keys` is defined.

        Args:
            batch: Dictionary of input tensors

        Returns:
            A tuple of tensors corresponding to each key in `batch_keys`
        """
        return tuple(get_keys(batch, *self.batch_keys))
Example #11
    def __call__(self, batch: Dict[str,
                                   Tensor]) -> Tuple[Tensor, Dict[str, float]]:
        """Compute Maximum Likelihood Estimation (MLE) loss for each model.

        Returns:
            A tuple with a 1d loss tensor containing each model's loss and a
            dictionary of loss statistics
        """
        obs, actions, next_obs = get_keys(batch, *self.batch_keys)
        logps = self.model_likelihoods(obs, actions, next_obs)
        loss = -torch.stack(logps)
        info = {f"loss(models[{i}])": -l.item() for i, l in enumerate(logps)}
        return loss, info
Example #12
    def __call__(self, batch: TensorDict) -> Tuple[Tensor, StatDict]:
        """Compute Maximum Likelihood Estimation (MLE) model loss.

        Returns:
            A tuple with a 0d loss tensor (the mean of the per-model losses)
            and a dictionary of per-model loss statistics
        """
        obs, act, new_obs = get_keys(batch, *self.batch_keys)
        nlls = self.loss_fns(obs, act, new_obs)

        losses = torch.stack(nlls)
        info = {f"{self.tag}(models[{i}])": n.item() for i, n in enumerate(nlls)}
        self._last_output = (losses, info)
        return losses.mean(), info
Example #13
def test_compute_value_targets(policy_and_batch):
    policy, batch = policy_and_batch

    rewards, dones = get_keys(batch, SampleBatch.REWARDS, SampleBatch.DONES)
    targets = policy.loss_critic.sampled_one_step_state_values(batch)
    assert targets.shape == (10, )
    assert targets.dtype == torch.float32
    assert torch.allclose(targets[dones], rewards[dones])

    policy.module.zero_grad()
    targets.mean().backward()
    target_params = set(policy.module.target_critic.parameters())
    other_params = (p for p in policy.module.parameters()
                    if p not in target_params)
    assert all(p.grad is not None for p in target_params)
    assert all(p.grad is None for p in other_params)
Example #14
def test_target_value(cdq_loss, batch, critics, target_critic):
    modules = nn.ModuleList([critics, target_critic])

    rewards, next_obs, dones = dutil.get_keys(batch, SampleBatch.REWARDS,
                                              SampleBatch.NEXT_OBS,
                                              SampleBatch.DONES)
    targets = cdq_loss.critic_targets(rewards, next_obs, dones)
    assert torch.is_tensor(targets)
    assert targets.shape == (len(next_obs), )
    assert targets.dtype == torch.float32
    assert torch.allclose(targets[dones], rewards[dones])

    modules.zero_grad()
    targets.mean().backward()
    target_params = set(target_critic.parameters())
    assert all(p.grad is not None for p in target_params)
    assert all(p.grad is None for p in set(critics.parameters()))
Example #15
    def _perform_line_search(self, pol_grad, surr_loss, batch_tensors):
        # pylint:disable=too-many-locals
        kl_clip = self.optimizers["actor"].state["kl_clip"]
        expected_improvement = sum(
            (g * p.grad.data).sum()
            for g, p in zip(pol_grad, self.module.actor.parameters())
        ).item()

        cur_obs, actions, old_logp, advantages = dutil.get_keys(
            batch_tensors,
            SampleBatch.CUR_OBS,
            SampleBatch.ACTIONS,
            SampleBatch.ACTION_LOGP,
            Postprocessing.ADVANTAGES,
        )

        @torch.no_grad()
        def f_barrier(scale):
            for par in self.module.actor.parameters():
                par.data.add_(par.grad.data, alpha=scale)
            new_logp = self.module.actor.log_prob(cur_obs, actions)
            for par in self.module.actor.parameters():
                par.data.sub_(par.grad.data, alpha=scale)
            surr_loss = self._compute_surr_loss(old_logp, new_logp, advantages)
            avg_kl = torch.mean(old_logp - new_logp)
            return surr_loss.item() if avg_kl < kl_clip else np.inf

        scale, expected_improvement, improvement = line_search(
            f_barrier,
            1,
            1,
            expected_improvement,
            y_0=surr_loss.item(),
            **self.config["line_search_options"],
        )
        improvement_ratio = (
            improvement / expected_improvement if expected_improvement else np.nan
        )
        info = {
            "expected_improvement": expected_improvement,
            "actual_improvement": improvement,
            "improvement_ratio": improvement_ratio,
        }
        for par in self.module.actor.parameters():
            par.data.add_(par.grad.data, alpha=scale)
        return info
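
`line_search` is called here (and in Example #22) with a barrier function, a starting point, a search direction, the first-order expected improvement and the initial loss, and is expected to return the accepted point together with the scaled expected and actual improvements. Its implementation is not shown on this page; the sketch below is one plausible backtracking scheme matching that calling convention, not the library's code, and the `accept_ratio`/`backtrack_ratio`/`max_backtracks` options are illustrative names:

# Hedged sketch of a backtracking line search matching the call sites in
# Examples #15 and #22: candidates are x_0 - ratio * dx for shrinking ratios,
# and the first one that improves the loss by a sufficient fraction of the
# (scaled) predicted improvement is accepted. Not the library's implementation.
def line_search(f, x_0, dx, expected_improvement, y_0=None,
                accept_ratio=0.1, backtrack_ratio=0.8, max_backtracks=15):
    if y_0 is None:
        y_0 = f(x_0)
    for exp in range(max_backtracks):
        ratio = backtrack_ratio ** exp
        x_new = x_0 - ratio * dx
        improvement = y_0 - f(x_new)
        if (improvement > 0
                and improvement >= accept_ratio * ratio * expected_improvement):
            return x_new, ratio * expected_improvement, improvement
    return x_0, expected_improvement, 0.0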
Example #16
    def __call__(self, batch: Dict[str,
                                   Tensor]) -> Tuple[Tensor, Dict[str, float]]:
        """Compute bootstrapped Stochatic Value Gradient loss."""
        assert (self._reward_fn is not None
                ), "No reward function set. Did you call `set_reward_fn`?"

        obs, actions, next_obs, dones, is_ratios = get_keys(
            batch,
            SampleBatch.CUR_OBS,
            SampleBatch.ACTIONS,
            SampleBatch.NEXT_OBS,
            SampleBatch.DONES,
            self.IS_RATIOS,
        )
        state_val = self.one_step_reproduced_state_value(
            obs, actions, next_obs, dones)
        svg_loss = -torch.mean(is_ratios * state_val)
        return svg_loss, {"loss(actor)": svg_loss.item()}
Example #17
def test_target_value(policy_and_batch):
    policy, batch = policy_and_batch
    loss_fn = loss_maker(policy)

    rewards, next_obs, dones = dutil.get_keys(batch, SampleBatch.REWARDS,
                                              SampleBatch.NEXT_OBS,
                                              SampleBatch.DONES)
    targets = loss_fn.critic_targets(rewards, next_obs, dones)
    assert targets.shape == (len(next_obs), )
    assert targets.dtype == torch.float32
    assert torch.allclose(targets[dones], rewards[dones])

    policy.module.zero_grad()
    targets.mean().backward()
    target_params = set(policy.module.target_critics.parameters())
    target_params.update(set(policy.module.actor.parameters()))
    assert all(p.grad is not None for p in target_params)
    assert all(p.grad is None
               for p in set(policy.module.parameters()) - target_params)
Example #18
    def sampled_one_step_state_values(self, batch):
        """Bootstrapped approximation of true state-value using sampled transition."""
        if self.ENTROPY in batch:
            entropy = batch[self.ENTROPY]
        else:
            with torch.no_grad():
                _, logp = self.actor(batch[SampleBatch.CUR_OBS])
                entropy = -logp

        next_obs, rewards, dones = dutil.get_keys(
            batch, SampleBatch.NEXT_OBS, SampleBatch.REWARDS, SampleBatch.DONES,
        )
        gamma = self.config["gamma"]
        augmented_rewards = rewards + self.alpha() * entropy
        return torch.where(
            dones,
            augmented_rewards,
            augmented_rewards + gamma * self.target_critic(next_obs).squeeze(-1),
        )
Example #19
    def _update_critic(self, batch_tensors):
        info = {}
        mse = nn.MSELoss()

        cur_obs, value_targets, value_preds = get_keys(
            batch_tensors,
            SampleBatch.CUR_OBS,
            Postprocessing.VALUE_TARGETS,
            SampleBatch.VF_PREDS,
        )

        for _ in range(self.config["val_iters"]):
            with self.optimizers.optimize("critic"):
                loss = mse(self.module.critic(cur_obs).squeeze(-1), value_targets)
                loss.backward()

        info["vf_loss"] = loss.item()
        info["explained_variance"] = explained_variance(
            value_targets.numpy(), value_preds.numpy()
        )
        return info
Example #20
def test_truncated_svg(policy_and_batch):
    policy, batch = policy_and_batch

    obs, actions, next_obs, rewards, dones = get_keys(
        batch,
        SampleBatch.CUR_OBS,
        SampleBatch.ACTIONS,
        SampleBatch.NEXT_OBS,
        SampleBatch.REWARDS,
        SampleBatch.DONES,
    )
    state_vals = policy.loss_actor.one_step_reproduced_state_value(
        obs, actions, next_obs, dones)
    assert state_vals.shape == (10, )
    assert state_vals.dtype == torch.float32
    assert torch.allclose(
        state_vals[dones],
        rewards[dones],
    )

    state_vals.mean().backward()
    assert all(p.grad is not None for p in policy.module.actor.parameters())
Example #21
    def _update_actor(self, batch_tensors):
        info = {}
        cur_obs, actions, advantages = get_keys(
            batch_tensors,
            SampleBatch.CUR_OBS,
            SampleBatch.ACTIONS,
            Postprocessing.ADVANTAGES,
        )
        advantages = (advantages - advantages.mean()) / (advantages.std() +
                                                         1e-8)

        # Compute Policy Gradient
        surr_loss = -(self.module.actor.log_prob(cur_obs, actions) *
                      advantages).mean()
        pol_grad = flat_grad(surr_loss, self.module.actor.parameters())
        info["grad_norm(pg)"] = pol_grad.norm().item()

        # Compute Natural Gradient
        descent_step, cg_info = self._compute_descent_step(pol_grad, cur_obs)
        info["grad_norm(nat)"] = descent_step.norm().item()
        info.update(cg_info)

        # Perform Line Search
        if self.config["line_search"]:
            new_params, line_search_info = self._perform_line_search(
                pol_grad,
                descent_step,
                surr_loss,
                batch_tensors,
            )
            info.update(line_search_info)
        else:
            new_params = (
                parameters_to_vector(self.module.actor.parameters()) -
                descent_step)

        vector_to_parameters(new_params, self.module.actor.parameters())
        return info
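
`_compute_descent_step` is not shown in this example. In TRPO-style natural policy gradient it typically solves F x = g, where F is the Fisher information matrix of the policy, using conjugate gradient so that F is only ever touched through Fisher-vector products. A generic conjugate-gradient solver one could use for that step is sketched below; `fvp` is a hypothetical callable returning F @ v (usually with damping added), and this is not the library's own routine:

# Hedged sketch: a generic conjugate-gradient solver for F x = b given only
# the matrix-vector product fvp(v) ~= F @ v. One way _compute_descent_step
# could obtain the natural-gradient direction; not the library's code.
import torch


def conjugate_gradient(fvp, b, n_iters=10, tol=1e-10):
    """Approximately solve F x = b using only Fisher-vector products."""
    x = torch.zeros_like(b)
    r = b.clone()
    p = b.clone()
    rdotr = r.dot(r)
    for _ in range(n_iters):
        avp = fvp(p)
        alpha = rdotr / p.dot(avp)
        x = x + alpha * p
        r = r - alpha * avp
        new_rdotr = r.dot(r)
        if new_rdotr < tol:
            break
        p = r + (new_rdotr / rdotr) * p
        rdotr = new_rdotr
    return x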
Example #22
    def _perform_line_search(self, pol_grad, descent_step, surr_loss,
                             batch_tensors):
        expected_improvement = pol_grad.dot(descent_step).item()

        cur_obs, actions, old_logp, advantages = get_keys(
            batch_tensors,
            SampleBatch.CUR_OBS,
            SampleBatch.ACTIONS,
            SampleBatch.ACTION_LOGP,
            Postprocessing.ADVANTAGES,
        )

        @torch.no_grad()
        def f_barrier(params):
            vector_to_parameters(params, self.module.actor.parameters())
            new_logp = self.module.actor.log_prob(cur_obs, actions)
            surr_loss = self._compute_surr_loss(old_logp, new_logp, advantages)
            avg_kl = torch.mean(old_logp - new_logp)
            return surr_loss.item() if avg_kl < self.config["delta"] else np.inf

        new_params, expected_improvement, improvement = line_search(
            f_barrier,
            parameters_to_vector(self.module.actor.parameters()),
            descent_step,
            expected_improvement,
            y_0=surr_loss.item(),
            **self.config["line_search_options"],
        )
        improvement_ratio = (improvement / expected_improvement
                             if expected_improvement else np.nan)
        info = {
            "expected_improvement": expected_improvement,
            "actual_improvement": improvement,
            "improvement_ratio": improvement_ratio,
        }
        return new_params, info
Example #23
def test_critic_loss(policy_and_batch):
    policy, batch = policy_and_batch
    loss_fn = loss_maker(policy)

    loss, info = loss_fn(batch)
    assert loss.shape == ()
    assert loss.dtype == torch.float32
    assert isinstance(info, dict)

    params = set(policy.module.critics.parameters())
    loss.backward()
    assert all(p.grad is not None for p in params)
    assert all(p.grad is None
               for p in set(policy.module.parameters()) - params)

    obs, acts = dutil.get_keys(batch, SampleBatch.CUR_OBS, SampleBatch.ACTIONS)
    vals = [m(obs, acts) for m in policy.module.critics]
    concat_vals = torch.cat(vals, dim=-1)
    targets = torch.randn_like(vals[0])
    loss_fn = nn.MSELoss()
    assert torch.allclose(
        loss_fn(concat_vals, targets.expand_as(concat_vals)),
        sum(loss_fn(val, targets) for val in vals) / len(vals),
    )