Example #1
    # Requires: import torch; from torch.distributions import Normal
    def _fitting_diag(self, best_samples):
        """
        Fit a diagonal-covariance Gaussian to the best samples and draw new action samples from it.

        Parameters
        ----------
        best_samples : torch.Tensor
            shape (self.cem_batch_size, self.num_best_sampling, self.dim_ac)

        Returns
        -------
        samples : torch.Tensor
            shape (self.num_sampling * self.cem_batch_size, self.dim_ac)
        """
        mean = torch.mean(
            best_samples, dim=1)  # (self.cem_batch_size, self.dim_ac)
        std = torch.std(
            best_samples, dim=1)  # (self.cem_batch_size, self.dim_ac)
        samples = Normal(loc=mean, scale=std).rsample(
            torch.Size((self.num_sampling,)))  # (self.num_sampling, self.cem_batch_size, self.dim_ac)
        samples = samples.transpose(1, 0)  # (self.cem_batch_size, self.num_sampling, self.dim_ac)
        samples = samples.reshape((self.num_sampling * self.cem_batch_size,
                                   self.dim_ac))  # (self.num_sampling * self.cem_batch_size, self.dim_ac)
        samples = self._clamp(samples)  # clamp to the valid action range (helper defined on the class)
        return samples
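
For context, here is a minimal, hypothetical sketch of how a refit-and-resample step like _fitting_diag is typically used inside a cross-entropy-method (CEM) planning loop. The score_fn, num_best, num_sampling, and num_iters names below are illustrative assumptions, not part of the original class.

import torch
from torch.distributions import Normal

def cem_plan(score_fn, init_samples, num_iters=5, num_best=50, num_sampling=500):
    """Toy CEM loop (single planning problem): score candidates, keep the
    elite set, refit a diagonal Gaussian, and resample from it."""
    samples = init_samples                               # (num_sampling, dim_ac)
    for _ in range(num_iters):
        scores = score_fn(samples)                       # (num_sampling,)
        elite = samples[scores.topk(num_best).indices]   # (num_best, dim_ac)
        mean = elite.mean(dim=0)                         # (dim_ac,)
        std = elite.std(dim=0)                           # (dim_ac,)
        samples = Normal(mean, std).sample((num_sampling,))  # resample candidates
    return mean                                          # final mean as the chosen action

The snippet above differs mainly in that it refits the Gaussian for a whole batch of planning problems at once (self.cem_batch_size), which is why it transposes and reshapes the sampled tensor before clamping.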
Example #2
    # Requires: from torch.distributions import Normal
    def sample_from_hyperpolicy(self):
        # Sample from the current hyperpolicy to update the current policy
        flattened_params = Normal(self.policy_params[self.current_policy_index],
                                  self.policy_std).sample()  # shape: (state_dim * action_dim,)
        self.sampled_params = flattened_params.reshape(self.state_dim, self.action_dim)
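
As a standalone illustration of the same pattern, the sketch below samples a flattened weight vector from a Gaussian hyperpolicy and reshapes it into a linear policy matrix. The dimensions, the 0.1 std, and the linear state-to-action mapping are assumptions for demonstration only.

import torch
from torch.distributions import Normal

state_dim, action_dim = 4, 2
param_mean = torch.zeros(state_dim * action_dim)      # hyperpolicy mean over flattened weights
param_std = 0.1 * torch.ones(state_dim * action_dim)  # hyperpolicy per-parameter std

flat_params = Normal(param_mean, param_std).sample()  # one draw: shape (state_dim * action_dim,)
policy_weights = flat_params.reshape(state_dim, action_dim)

state = torch.randn(state_dim)
action = state @ policy_weights                       # deterministic linear policy for this rollout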