def _fitting_diag(self, best_samples):
    """
    Fit a diagonal-covariance Gaussian to the elite samples and draw from it.

    Parameters
    ----------
    best_samples : torch.Tensor
        Elite samples, expected shape
        (self.cem_batch_size, self.num_best_sampling, self.dim_ac).

    Returns
    -------
    samples : torch.Tensor
        Fresh draws of shape
        (self.cem_batch_size * self.num_sampling, self.dim_ac), after being
        passed through ``self._clamp``. Rows are grouped by batch element:
        the ``self.num_sampling`` draws of batch element 0 come first, then
        those of element 1, and so on.
    """
    # Per-batch, per-dimension statistics taken over the elite axis.
    mean = torch.mean(best_samples, dim=1)  # (cem_batch_size, dim_ac)
    if best_samples.shape[1] > 1:
        std = torch.std(best_samples, dim=1)  # (cem_batch_size, dim_ac)
    else:
        # The unbiased std of a single sample is NaN; use zero spread instead.
        std = torch.zeros_like(mean)
    # Normal requires a strictly positive scale. Identical elites give
    # std == 0, which would make the distribution invalid, so floor it at
    # machine epsilon for the tensor's dtype.
    std = std.clamp(min=torch.finfo(std.dtype).eps)

    # rsample keeps the draw differentiable w.r.t. mean and std.
    samples = Normal(loc=mean, scale=std).rsample(
        torch.Size((self.num_sampling,)))  # (num_sampling, cem_batch_size, dim_ac)
    samples = samples.transpose(1, 0)  # (cem_batch_size, num_sampling, dim_ac)
    samples = samples.reshape(
        (self.num_sampling * self.cem_batch_size, self.dim_ac))
    # _clamp is defined elsewhere; presumably it bounds the values to the
    # valid action range -- confirm against its definition.
    samples = self._clamp(samples)
    return samples
def sample_from_hyperpolicy(self):
    """Draw one parameter set from the current hyperpolicy and store it.

    A Normal distribution is built with mean
    ``self.policy_params[self.current_policy_index]`` and scale
    ``self.policy_std``. A single draw is taken and reshaped to
    ``(self.state_dim, self.action_dim)``, then stored in
    ``self.sampled_params``. Nothing is returned.
    """
    hyper_mean = self.policy_params[self.current_policy_index]
    hyperpolicy = Normal(hyper_mean, self.policy_std)
    # The flat draw must hold state_dim * action_dim elements for the
    # reshape below to succeed.
    flat_draw = hyperpolicy.sample()
    self.sampled_params = flat_draw.reshape(self.state_dim, self.action_dim)