Example #1
    def forward(self, state, action=torch.tensor(float("nan"))):
        """Get value at a given state-(action) through simulation.

        Parameters
        ----------
        state: Tensor.
            State where to evaluate the value.
        action: Tensor, optional.
            First action of simulation.
        """
        unfreeze_parameters(self.policy)
        with DisableGradient(self.sim.dynamical_model, self.sim.reward_model,
                             self.sim.termination_model):
            sim_observation = self.sim.simulate(state, self.policy, action)

        if isinstance(self.value_function, IntegrateQValueFunction):
            cm = DisableGradient(self.value_function.q_function)
        else:
            cm = DisableGradient(self.value_function)
        with cm:
            v = mc_return(
                sim_observation,
                gamma=self.gamma,
                lambda_=self.lambda_,
                value_function=self.value_function,
                reward_transformer=self.reward_transformer,
                entropy_regularization=self.entropy_regularization,
                reduction="none",
            )

        v = v.reshape(self.sim.num_samples, state.shape[0], -1).mean(0)
        v = v[:, 0]  # In case of ensembles, return the first component.
        return v
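
All of these examples rely on `DisableGradient` to freeze the parameters of the model (or critic) while gradients still flow through the rest of the computation; `unfreeze_parameters(self.policy)` above re-enables gradients on the policy beforehand. The sketch below is not the library's implementation, only a minimal stand-in (here called `FreezeParameters`) built on the assumption that the context manager simply toggles `requires_grad` on the parameters of the modules it receives and restores the previous flags on exit:

import torch.nn as nn


class FreezeParameters:
    """Hypothetical stand-in for DisableGradient: temporarily disable gradients."""

    def __init__(self, *modules):
        self.modules = [m for m in modules if m is not None]
        self._saved = []

    def __enter__(self):
        self._saved = []
        for module in self.modules:
            for param in module.parameters():
                self._saved.append((param, param.requires_grad))
                param.requires_grad_(False)

    def __exit__(self, *exc_info):
        for param, flag in self._saved:
            param.requires_grad_(flag)
        return False
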
Example #2
    def simulate_and_learn_policy(self):
        """Simulate the model and optimize the policy with the learned data.

        This consists of two steps:
            Step 1: Simulate trajectories with the model.
                Calls self.simulate_model().
            Step 2: Optimize the policy with a model-free RL method.
                Calls self.learn_policy(). To be implemented by a base class.
        """
        print(colorize("Optimizing Policy with Model Data", "yellow"))
        self.dynamical_model.eval()
        self.sim_dataset.reset()  # Erase simulation data set before starting.
        with DisableGradient(
                self.dynamical_model), gpytorch.settings.fast_pred_var():
            for i in tqdm(range(self.policy_opt_num_iter)):
                # Step 1: Compute the state distribution
                with torch.no_grad():
                    self.simulate_model()

                # Log last simulations.
                self._log_simulated_trajectory()

                # Step 2: Optimize policy
                self.learn_policy()

                if (self.sim_refresh_interval > 0
                        and (i + 1) % self.sim_refresh_interval == 0):
                    self.sim_dataset.reset()
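
Note the two different gradient controls in this loop: `DisableGradient` keeps the dynamical model's parameters frozen for the whole policy-optimization phase, while `torch.no_grad()` additionally prevents any graph from being recorded during the rollout in Step 1. The snippet below is a small, self-contained illustration of that difference, using a plain `nn.Linear` as a toy stand-in for the learned model:

import torch
import torch.nn as nn

model = nn.Linear(3, 1)  # toy stand-in for the learned dynamical model
x = torch.randn(5, 3, requires_grad=True)

# Under no_grad, no computation graph is recorded at all.
with torch.no_grad():
    y = model(x)
print(y.requires_grad)  # False

# With only the parameters frozen, gradients still flow to the inputs.
for p in model.parameters():
    p.requires_grad_(False)
y = model(x)
y.sum().backward()
print(x.grad is not None)         # True: the input receives a gradient
print(model.weight.grad is None)  # True: the frozen parameters do not
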
Example #3
    def actor_loss(self, observation):
        """Use the model to compute the gradient loss."""
        state, action = observation.state, observation.action
        next_state, done = observation.next_state, observation.done

        # Infer eta.
        action_mean, action_chol = self.policy(state)
        with torch.no_grad():
            eta = torch.inverse(action_chol) @ (
                (action - action_mean).unsqueeze(-1))

        # Compute entropy and log_probability.
        pi = tensor_to_distribution((action_mean, action_chol))
        _, log_p = get_entropy_and_log_p(pi, action, self.policy.action_scale)

        # Compute off-policy weight.
        with torch.no_grad():
            weight = self.get_ope_weight(state, action,
                                         observation.log_prob_action)

        with DisableGradient(
                self.dynamical_model,
                self.reward_model,
                self.termination_model,
                self.critic_target,
        ):
            # Compute re-parameterized policy sample.
            action = (action_mean + (action_chol @ eta).squeeze(-1)).clamp(
                -1, 1)

            # Infer xi.
            ns_mean, ns_chol = self.dynamical_model(state, action)
            with torch.no_grad():
                xi = torch.inverse(ns_chol) @ (
                    (next_state - ns_mean).unsqueeze(-1))

            # Compute re-parameterized next-state sample.
            ns = ns_mean + (ns_chol @ xi).squeeze(-1)

            # Compute reward.
            r = tensor_to_distribution(self.reward_model(state, action,
                                                         ns)).rsample()
            r = r[..., 0]

            next_v = self.value_function(ns)
            if isinstance(self.critic,
                          (NNEnsembleValueFunction, NNEnsembleQFunction)):
                next_v = next_v[..., 0]

            v = r + self.gamma * next_v * (1 - done)

        return Loss(policy_loss=-(weight * v)).reduce(self.criterion.reduction)
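
The eta/xi computations above are the re-parameterization trick for a multivariate Gaussian: the noise that produced the logged action (or next state) is recovered with the inverse Cholesky factor and then re-applied to the current, differentiable mean and scale, so the resulting sample carries gradients with respect to the policy and model parameters. In isolation, and assuming only a mean vector and a Cholesky factor as inputs, the step looks roughly like this:

import torch

mean = torch.randn(4, requires_grad=True)      # toy differentiable mean (e.g. a policy output)
chol = 0.5 * torch.eye(4, requires_grad=True)  # toy differentiable Cholesky factor
action = torch.randn(4)                        # action observed in the data

# Recover the noise that would have generated `action`; no gradient is needed here.
with torch.no_grad():
    eta = torch.inverse(chol) @ (action - mean).unsqueeze(-1)

# Rebuild the same action from the current mean/scale; this sample is differentiable.
reparam_action = mean + (chol @ eta).squeeze(-1)
assert torch.allclose(reparam_action, action, atol=1e-5)
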
Example #4
    def simulate(self,
                 state,
                 policy,
                 initial_action=None,
                 logger=None,
                 stack_obs=False):
        """Simulate from initial_states."""
        self.dynamical_model.eval()
        with DisableGradient(
                self.dynamical_model, self.reward_model,
                self.termination_model), gpytorch.settings.fast_pred_var():
            trajectory = super().simulate(state, policy, stack_obs=stack_obs)

        self._log_trajectory(trajectory)
        return trajectory
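
The `gpytorch.settings.fast_pred_var()` context manager speeds up predictive-variance computation for GPyTorch models; it has no effect if the dynamical model is not a GP. A self-contained toy example of the setting in ordinary GPyTorch code follows (the `ToyGP` class and its data are made up purely for illustration):

import gpytorch
import torch


class ToyGP(gpytorch.models.ExactGP):
    """Minimal exact GP used only to illustrate fast_pred_var."""

    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x), self.covar_module(x)
        )


train_x = torch.linspace(0, 1, 20)
train_y = torch.sin(6 * train_x)
likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = ToyGP(train_x, train_y, likelihood)

model.eval()
likelihood.eval()
test_x = torch.linspace(0, 1, 50)
with torch.no_grad(), gpytorch.settings.fast_pred_var():
    pred = likelihood(model(test_x))
    mean, var = pred.mean, pred.variance
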
Example #5
    def actor_loss(self, observation) -> Loss:
        """Compute Actor loss."""
        state = observation.state[..., 0, :]
        action = observation.action[..., 0, :]
        action_mean, action_chol = self.policy(state)

        # Infer eta.
        with torch.no_grad():
            delta = action / self.policy.action_scale - action_mean
            eta = torch.inverse(action_chol) @ delta.unsqueeze(-1)

        # Compute re-parameterized policy sample.
        action = self.policy.action_scale * (
            action_mean + (action_chol @ eta).squeeze(-1)).clamp(-1.0, 1.0)

        # Propagate gradient.
        with DisableGradient(self.critic):
            q = self.critic(observation.state[..., 0, :], action)
            if isinstance(self.critic, NNEnsembleQFunction):
                q = q[..., 0]

        return Loss(policy_loss=-q).reduce(self.criterion.reduction)
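
Apart from the re-parameterization around the logged action, this loss is a pathwise (DPG-style) policy gradient: a differentiable action is pushed through a frozen critic and the negative Q-value is minimized. A minimal, self-contained version of that idea, with toy linear networks standing in for the policy and critic:

import torch
import torch.nn as nn

policy = nn.Linear(3, 2)            # toy deterministic policy
critic = nn.Linear(3 + 2, 1)        # toy Q-function over (state, action) pairs

state = torch.randn(16, 3)
action = torch.tanh(policy(state))  # bounded action, differentiable w.r.t. the policy

# Freeze the critic so that only the policy receives gradients.
for p in critic.parameters():
    p.requires_grad_(False)

q = critic(torch.cat([state, action], dim=-1))
policy_loss = -q.mean()
policy_loss.backward()  # populates policy.weight.grad but not critic.weight.grad
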
Example #6
    def learn(self, memory=None):
        """Learn a policy with the model."""

        def closure():
            """Gradient calculation."""
            if memory is None:
                observation, *_ = self.memory.sample_batch(self.batch_size)
            else:
                observation, *_ = memory.sample_batch(self.batch_size)
            self.optimizer.zero_grad()
            losses = self.algorithm(observation.clone())
            losses.combined_loss.mean().backward()

            torch.nn.utils.clip_grad_norm_(self.algorithm.parameters(),
                                           self.clip_gradient_val)
            return losses

        with DisableGradient(self.dynamical_model, self.reward_model,
                             self.termination_model):
            self._learn_steps(closure)
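
The `closure` used here follows the standard PyTorch pattern for `Optimizer.step(closure=...)`: the closure zeroes the gradients, computes the loss, calls `backward()`, optionally clips gradients, and returns the loss, which also makes the code compatible with optimizers such as L-BFGS that re-evaluate the objective. A bare-bones, self-contained version with a toy regression model:

import torch
import torch.nn as nn

model = nn.Linear(4, 1)  # toy regression model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
inputs, targets = torch.randn(64, 4), torch.randn(64, 1)


def closure():
    """Zero the gradients, compute the loss, backpropagate, and return the loss."""
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    return loss


for _ in range(100):
    loss = optimizer.step(closure=closure)
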
Example #7
    def _learn_steps(self, closure):
        """Apply `num_iter' learn steps to closure function."""
        for _ in tqdm(range(self.num_iter), disable=not self._training_verbose):
            if self.train_steps % self.policy_update_frequency == 0:
                cm = contextlib.nullcontext()
            else:
                cm = DisableGradient(self.policy)

            with cm:
                losses = self.optimizer.step(closure=closure)  # type: Loss

            self.logger.update(**asdict(average_dataclass(losses)))
            self.logger.update(**self.algorithm.info())

            self.counters["train_steps"] += 1
            if self.train_steps % self.target_update_frequency == 0:
                self.algorithm.update()
                for param in self.params.values():
                    param.update()

            if self.early_stop(losses, **self.algorithm.info()):
                break
        self.algorithm.reset()
        self.early_stopping_algorithm.reset()
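
Switching between `contextlib.nullcontext()` and `DisableGradient(self.policy)` implements delayed policy updates: the policy's parameters only receive gradients every `policy_update_frequency` optimizer steps, while the remaining losses are optimized at every step. The toy loop below reproduces just that alternation, using a hypothetical `frozen` helper in place of `DisableGradient`:

import contextlib

import torch.nn as nn


@contextlib.contextmanager
def frozen(module):
    """Toy freezer: temporarily disable gradients for a module's parameters."""
    flags = [p.requires_grad for p in module.parameters()]
    for p in module.parameters():
        p.requires_grad_(False)
    try:
        yield
    finally:
        for p, flag in zip(module.parameters(), flags):
            p.requires_grad_(flag)


policy = nn.Linear(3, 2)  # toy policy network
policy_update_frequency = 2

for step in range(6):
    cm = contextlib.nullcontext() if step % policy_update_frequency == 0 else frozen(policy)
    with cm:
        trainable = any(p.requires_grad for p in policy.parameters())
        print(step, trainable)  # True on steps 0, 2, 4; False otherwise
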
Example #8
    def learn_policy(self):
        """Optimize the policy."""
        # Iterate over state batches in the state distribution
        self.algorithm.reset()
        for _ in range(self.policy_opt_gradient_steps):

            def closure():
                """Gradient calculation."""
                states = Observation(
                    state=self.sim_dataset.sample_batch(self.policy_opt_batch_size)
                )
                self.optimizer.zero_grad()
                losses = self.algorithm(states)
                losses.combined_loss.backward()
                return losses

            if self.train_steps % self.policy_update_frequency == 0:
                cm = contextlib.nullcontext()
            else:
                cm = DisableGradient(self.policy)

            with cm:
                losses = self.optimizer.step(closure=closure)

            self.logger.update(**asdict(average_dataclass(losses)))
            self.logger.update(**self.algorithm.info())

            self.counters["train_steps"] += 1
            if self.train_steps % self.policy_opt_target_update_frequency == 0:
                self.algorithm.update()
                for param in self.params.values():
                    param.update()

            if self.early_stop(losses, **self.algorithm.info()):
                break

        self.algorithm.reset()
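
`self.algorithm.update()` and `param.update()` are called every `policy_opt_target_update_frequency` steps; in agents of this kind such calls typically refresh target networks and learnable temperature-like parameters, although the exact mechanics are not shown here. A common choice, sketched below under that assumption, is a Polyak (soft) update of a target network toward the online network:

import torch
import torch.nn as nn


def soft_update(target, source, tau=0.005):
    """Polyak averaging: target <- (1 - tau) * target + tau * source."""
    with torch.no_grad():
        for t_param, s_param in zip(target.parameters(), source.parameters()):
            t_param.mul_(1.0 - tau).add_(tau * s_param)


critic = nn.Linear(4, 1)         # toy online critic
critic_target = nn.Linear(4, 1)  # toy target critic
critic_target.load_state_dict(critic.state_dict())  # start from identical weights

soft_update(critic_target, critic)  # call every target-update interval
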