Example 1
Samples an action from the learned prior without tracking gradients or updating BatchNorm statistics, then tanh-squashes the sampled action.
 def sample_rand(self, obs):
     with torch.no_grad():
         with no_batchnorm_update(self.prior_net):
             prior_dist = self.prior_net.compute_learned_prior(obs, first_only=True).detach()
     action = prior_dist.sample()
     action, log_prob = self._tanh_squash_output(action, 0)        # ignore log_prob output
     return AttrDict(action=action, log_prob=log_prob)
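The examples on this page all wrap network calls in no_batchnorm_update so that BatchNorm running statistics are not updated by evaluation-time forward passes. The repository's own implementation is not shown here; a minimal sketch of such a context manager, assuming it only needs to switch BatchNorm submodules into eval mode and restore their previous mode afterwards, could look like this:

import contextlib
import torch.nn as nn

@contextlib.contextmanager
def no_batchnorm_update(module):
    """Temporarily put all BatchNorm submodules of `module` into eval mode so
    their running mean/variance are not updated, then restore their modes."""
    bn_layers = [m for m in module.modules()
                 if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d))]
    previous_modes = [m.training for m in bn_layers]
    for m in bn_layers:
        m.eval()
    try:
        yield module
    finally:
        for m, mode in zip(bn_layers, previous_modes):
            m.train(mode)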
Example 2
Computes the divergence between the policy output and the learned prior, either analytically or via Monte-Carlo estimation, with BatchNorm updates disabled on the prior network.
 def _compute_prior_divergence(self, policy_output, obs):
     with no_batchnorm_update(self.prior_net):
         prior_dist = self.prior_net.compute_learned_prior(obs, first_only=True).detach()
         if self._hp.analytic_KL:
             return self._analytic_divergence(policy_output, prior_dist), prior_dist
         return self._mc_divergence(policy_output, prior_dist), prior_dist
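The analytic_KL hyperparameter in Example 2 switches between a closed-form KL divergence and a Monte-Carlo estimate. The repository's _analytic_divergence and _mc_divergence are not shown on this page; a rough sketch of the two options for Gaussian distributions, using torch.distributions rather than the repository's own distribution classes:

import torch
from torch.distributions import Normal, kl_divergence

policy_dist = Normal(torch.zeros(10), torch.ones(10))
prior_dist = Normal(torch.ones(10), 2 * torch.ones(10))

# Analytic KL: closed-form expression, exact and low-variance.
analytic_kl = kl_divergence(policy_dist, prior_dist)

# Monte-Carlo KL: estimate E_{z~policy}[log policy(z) - log prior(z)] from samples.
z = policy_dist.rsample((128,))
mc_kl = (policy_dist.log_prob(z) - prior_dist.log_prob(z)).mean(dim=0)

print(analytic_kl.mean().item(), mc_kl.mean().item())  # the two values should roughly agree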
Example 3
Decodes a new latent action plan when the current plan is empty, disabling BatchNorm updates for single-observation batches, and returns the next action from the plan.
 def _act(self, obs):
     assert len(obs.shape) == 2 and obs.shape[0] == 1  # assume single-observation batches with leading 1-dim
     if not self.action_plan:
         # generate action plan if the current one is empty
         split_obs = self._split_obs(obs)
         with no_batchnorm_update(self._policy) if obs.shape[0] == 1 else contextlib.suppress():
             actions = self._policy.decode(map2torch(split_obs.z, self._hp.device),
                                           map2torch(split_obs.cond_input, self._hp.device),
                                           self._policy.n_rollout_steps)
         self.action_plan = deque(split_along_axis(map2np(actions), axis=1))
     return AttrDict(action=self.action_plan.popleft())
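In Example 3 the with statement picks a context manager at run time: no_batchnorm_update(self._policy) when the batch holds a single observation, and contextlib.suppress() with no arguments, which is simply a do-nothing context manager, otherwise. A small sketch of this conditional pattern, reusing the no_batchnorm_update sketch from above and a made-up toy decoder:

import contextlib
import torch
import torch.nn as nn

# Hypothetical stand-in for the policy decoder, used only for illustration.
decoder = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Linear(8, 2))

def decode(z):
    # Freeze BatchNorm only for single-sample batches, where batch statistics
    # are degenerate; otherwise run the decoder unchanged.
    ctx = no_batchnorm_update(decoder) if z.shape[0] == 1 else contextlib.suppress()
    with ctx:
        return decoder(z)

print(decode(torch.randn(1, 4)).shape)    # single observation: BN statistics frozen
print(decode(torch.randn(16, 4)).shape)   # full batch: normal BN behaviour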
Example 4
Wraps the parent class's forward pass so that BatchNorm statistics are not updated.
 def forward(self, obs):
     with no_batchnorm_update(self):  # BN updates harm the initialized policy
         return super().forward(obs)
Example 5
Same pattern as Example 4, but dispatches directly to LearnedPriorAugmentedPolicy.forward rather than through super().
 def forward(self, obs):
     with no_batchnorm_update(self):
         return LearnedPriorAugmentedPolicy.forward(self, obs)
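Examples 4 and 5 wrap the same forward pass but dispatch it differently: super().forward(obs) follows the method resolution order, while LearnedPriorAugmentedPolicy.forward(self, obs) calls that specific base class directly, skipping any classes in between. A minimal illustration of the difference, using made-up class names:

class Base:
    def forward(self, x):
        return f"Base({x})"

class Middle(Base):
    def forward(self, x):
        return f"Middle({super().forward(x)})"

class ChildA(Middle):
    def forward(self, x):
        return super().forward(x)        # follows the MRO: Middle, then Base

class ChildB(Middle):
    def forward(self, x):
        return Base.forward(self, x)     # calls Base directly, skipping Middle

print(ChildA().forward(1))  # Middle(Base(1))
print(ChildB().forward(1))  # Base(1)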