Example 1
 def _act(self, obs):
     # TODO implement non-sampling validation mode
     obs = map2torch(self._obs_normalizer(obs), self._hp.device)
     if len(obs.shape) == 1:  # we need batched inputs for policy
         policy_output = self._remove_batch(self.policy(obs[None]))
         if 'dist' in policy_output:
             del policy_output['dist']  # the distribution object is not needed for acting and cannot be mapped to numpy
         return map2np(policy_output)
     return map2np(self.policy(obs))
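
For reference, a minimal sketch (not from the source) of the batching convention this method relies on: a single observation is a 1-D array, the policy expects a leading batch dimension, and _remove_batch is assumed to strip that singleton dimension again before the outputs are converted to numpy.

    import torch

    # Unbatched observation -> add a leading batch dim for the policy -> strip it again.
    obs = torch.zeros(17)   # single observation, shape (17,)
    batched = obs[None]     # shape (1, 17): what self.policy() receives
    single = batched[0]     # shape (17,): what _remove_batch is assumed to recover
    assert single.shape == obs.shape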
Example 2
    def run(self, inputs, use_learned_prior=True):
        """Policy interface for model. Runs decoder if action plan is empty, otherwise returns next action from action plan.
        :arg inputs: dict with 'states', 'actions', 'images' keys from environment
        :arg use_learned_prior: if True, samples the latent from the learned prior; otherwise samples it from a unit Gaussian prior
        """
        if not self._action_plan:
            inputs = map2torch(inputs, device=self.device)

            # sample latent variable from prior
            if use_learned_prior:
                z = self.compute_learned_prior(self._learned_prior_input(inputs), first_only=True).sample()
            else:
                z = Gaussian(torch.zeros((1, self._hp.nz_vae * 2), device=self.device)).sample()

            # decode into action plan
            # HACK: the flat LSTM decoder can only take batch_size inputs, so tile z and the conditioning observation
            z = z.repeat(self._hp.batch_size, 1)
            input_obs = self._learned_prior_input(inputs).repeat(self._hp.batch_size, 1)
            actions = self.decode(z, cond_inputs=input_obs, steps=self._hp.n_rollout_steps)[0]
            self._action_plan = deque(split_along_axis(map2np(actions), axis=0))

        return AttrDict(action=self._action_plan.popleft()[None])
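
The plan-caching logic above can be condensed into a small self-contained sketch; PlanCache and the decode_fn stub are illustrative names only, and np.split stands in for split_along_axis, which is assumed to split the decoded plan into per-step actions.

    from collections import deque
    import numpy as np

    class PlanCache:
        """Decode a full H-step action plan once, then pop one action per call."""
        def __init__(self, decode_fn):
            self._decode_fn = decode_fn          # hypothetical stand-in for the model's decoder
            self._action_plan = deque()

        def next_action(self):
            if not self._action_plan:            # plan exhausted -> decode a new one
                plan = self._decode_fn()         # e.g. array of shape (H, action_dim)
                self._action_plan = deque(np.split(plan, plan.shape[0], axis=0))
            return self._action_plan.popleft()   # next action, shape (1, action_dim)

    cache = PlanCache(lambda: np.zeros((10, 4)))              # 10-step plan of 4-D actions
    first, second = cache.next_action(), cache.next_action()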
Example 3
 def _split_obs(self, obs):
     """Splits the flat policy observation into the decoder's conditioning input and the appended latent z."""
     unflattened_obs = map2np(self._policy.unflatten_obs(
         map2torch(obs[:, :-self._policy.latent_dim], device=self.device)))
     return AttrDict(
         cond_input=unflattened_obs.prior_obs,
         z=obs[:, -self._policy.latent_dim:],
     )
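
A small sketch (with made-up dimensions) of the flat observation layout this split assumes: the latent z is concatenated onto the end of the conditioning observation, so slicing off the last latent_dim columns recovers both parts.

    import numpy as np

    latent_dim = 10
    cond = np.random.rand(1, 60)                     # flattened conditioning observation
    z = np.random.rand(1, latent_dim)                # latent that was appended to the observation
    flat_obs = np.concatenate([cond, z], axis=-1)    # shape (1, 70)
    assert np.allclose(flat_obs[:, :-latent_dim], cond)
    assert np.allclose(flat_obs[:, -latent_dim:], z)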
Example 4
 def _act(self, obs):
     assert len(obs.shape) == 2 and obs.shape[0] == 1  # assume single-observation batches with leading 1-dim
     if not self.action_plan:
         # generate action plan if the current one is empty
         split_obs = self._split_obs(obs)
         # freeze batchnorm statistics when decoding from a single observation
         with no_batchnorm_update(self._policy) if obs.shape[0] == 1 else contextlib.suppress():
             actions = self._policy.decode(map2torch(split_obs.z, self._hp.device),
                                           map2torch(split_obs.cond_input, self._hp.device),
                                           self._policy.n_rollout_steps)
         self.action_plan = deque(split_along_axis(map2np(actions), axis=1))
     return AttrDict(action=self.action_plan.popleft())
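
no_batchnorm_update is used above as a conditional context manager. Below is a hedged stand-in, assuming it temporarily switches BatchNorm layers to eval mode so a single-sample forward pass does not update their running statistics; this mirrors the name and usage, not the library's actual implementation.

    import contextlib
    import torch
    import torch.nn as nn

    @contextlib.contextmanager
    def no_batchnorm_update_sketch(model):
        # Put all BatchNorm layers into eval mode for the duration of the block,
        # then restore their previous training/eval state.
        bn_layers = [m for m in model.modules() if isinstance(m, nn.modules.batchnorm._BatchNorm)]
        prev_modes = [m.training for m in bn_layers]
        for m in bn_layers:
            m.eval()
        try:
            yield model
        finally:
            for m, mode in zip(bn_layers, prev_modes):
                m.train(mode)

    # Same conditional pattern as above: freeze batchnorm only for single-sample
    # batches, otherwise use a no-op context.
    model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8))
    batch_size = 1
    with no_batchnorm_update_sketch(model) if batch_size == 1 else contextlib.suppress():
        out = model(torch.zeros(batch_size, 4))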
Example 5
 def _act_rand(self, obs):
     policy_output = self.policy.sample_rand(
         map2torch(obs, self.policy.device))
     if 'dist' in policy_output:
         del policy_output['dist']
     return map2np(policy_output)
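
Both this method and _act in Example 1 drop the 'dist' entry before converting outputs to numpy. A short self-contained illustration of why: torch distribution objects are not tensors, so a dict-wide tensor-to-numpy conversion would fail on them.

    import torch
    from torch.distributions import Normal

    policy_output = {'action': torch.zeros(1, 4), 'dist': Normal(0., 1.)}
    policy_output.pop('dist', None)     # drop the non-array entry before converting to numpy
    as_numpy = {k: v.detach().cpu().numpy() for k, v in policy_output.items()}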