def sample_rand(self, obs):
    with torch.no_grad():
        with no_batchnorm_update(self.prior_net):
            prior_dist = self.prior_net.compute_learned_prior(obs, first_only=True).detach()
    action = prior_dist.sample()
    action, log_prob = self._tanh_squash_output(action, 0)   # ignore log_prob output
    return AttrDict(action=action, log_prob=log_prob)
def _compute_prior_divergence(self, policy_output, obs):
    with no_batchnorm_update(self.prior_net):
        prior_dist = self.prior_net.compute_learned_prior(obs, first_only=True).detach()
        if self._hp.analytic_KL:
            # closed-form KL between the policy output and the learned prior
            return self._analytic_divergence(policy_output, prior_dist), prior_dist
        # otherwise fall back to a Monte-Carlo estimate of the divergence
        return self._mc_divergence(policy_output, prior_dist), prior_dist
def _act(self, obs):
    assert len(obs.shape) == 2 and obs.shape[0] == 1   # assume single-observation batches with leading 1-dim
    if not self.action_plan:
        # generate action plan if the current one is empty
        split_obs = self._split_obs(obs)
        with no_batchnorm_update(self._policy) if obs.shape[0] == 1 else contextlib.suppress():
            # decode the latent into a sequence of low-level actions
            actions = self._policy.decode(map2torch(split_obs.z, self._hp.device),
                                          map2torch(split_obs.cond_input, self._hp.device),
                                          self._policy.n_rollout_steps)
        # cache the decoded steps and serve one action per call
        self.action_plan = deque(split_along_axis(map2np(actions), axis=1))
    return AttrDict(action=self.action_plan.popleft())
def forward(self, obs):
    with no_batchnorm_update(self):     # BN updates harm the initialized policy
        return super().forward(obs)
def forward(self, obs):
    with no_batchnorm_update(self):
        # call LearnedPriorAugmentedPolicy's forward directly instead of super()
        return LearnedPriorAugmentedPolicy.forward(self, obs)