def _act(self, obs):
    # TODO: implement non-sampling validation mode
    obs = map2torch(self._obs_normalizer(obs), self._hp.device)
    if len(obs.shape) == 1:     # we need batched inputs for policy
        policy_output = self._remove_batch(self.policy(obs[None]))
        if 'dist' in policy_output:
            del policy_output['dist']
        return map2np(policy_output)
    return map2np(self.policy(obs))
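# A minimal, self-contained sketch (not from the original class; the policy and
# helper below are hypothetical stand-ins) of the batching pattern in `_act`:
# an unbatched observation gets a leading batch dim via `obs[None]` before the
# policy call, and that dim is stripped again from every tensor in the output.
import torch

def _sketch_policy(obs):  # hypothetical policy: expects (B, obs_dim) inputs
    return {'action': obs.sum(dim=-1, keepdim=True), 'dist': object()}

def _sketch_remove_batch(output):  # drop the leading 1-dim from tensor outputs
    return {k: v[0] if torch.is_tensor(v) else v for k, v in output.items()}

sketch_obs = torch.zeros(4)                                       # unbatched, shape (obs_dim,)
sketch_out = _sketch_remove_batch(_sketch_policy(sketch_obs[None]))  # batched call, (1, obs_dim)
sketch_out.pop('dist', None)  # distribution objects cannot be converted to numpy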
def run(self, inputs, use_learned_prior=True):
    """Policy interface for model. Runs decoder if action plan is empty, otherwise returns next action from plan.
    :arg inputs: dict with 'states', 'actions', 'images' keys from environment
    :arg use_learned_prior: if True, uses learned prior, otherwise samples latent from unit Gaussian prior
    """
    if not self._action_plan:
        inputs = map2torch(inputs, device=self.device)

        # sample latent variable from prior
        z = self.compute_learned_prior(self._learned_prior_input(inputs), first_only=True).sample() \
            if use_learned_prior \
            else Gaussian(torch.zeros((1, self._hp.nz_vae * 2), device=self.device)).sample()

        # decode into action plan
        z = z.repeat(self._hp.batch_size, 1)  # HACK: the flat LSTM decoder can only take batch_size inputs
        input_obs = self._learned_prior_input(inputs).repeat(self._hp.batch_size, 1)
        actions = self.decode(z, cond_inputs=input_obs, steps=self._hp.n_rollout_steps)[0]
        self._action_plan = deque(split_along_axis(map2np(actions), axis=0))

    return AttrDict(action=self._action_plan.popleft()[None])
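# A minimal sketch (assumption: `Gaussian` packs mean and log-sigma halves along
# the last dim, which is what the `nz_vae * 2` zeros above rely on) showing why
# all-zero parameters give a unit Gaussian: mean = 0 and log_sigma = 0, i.e. sigma = 1.
import torch

class _SketchGaussian:
    def __init__(self, params):
        self.mu, self.log_sigma = torch.chunk(params, 2, dim=-1)

    def sample(self):
        return self.mu + self.log_sigma.exp() * torch.randn_like(self.mu)  # reparameterized sample

_nz_vae = 10
_z = _SketchGaussian(torch.zeros((1, _nz_vae * 2))).sample()  # ~ N(0, I), shape (1, nz_vae)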
def _split_obs(self, obs):
    unflattened_obs = map2np(self._policy.unflatten_obs(
        map2torch(obs[:, :-self._policy.latent_dim], device=self.device)))
    return AttrDict(
        cond_input=unflattened_obs.prior_obs,    # conditioning input for the decoder
        z=obs[:, -self._policy.latent_dim:],     # latent skill appended to the flat observation
    )
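# A minimal numpy sketch (dimensions chosen for illustration) of the flat
# observation layout `_split_obs` expects: each row is the conditioning
# observation with the latent z concatenated at the end, so slicing with
# `latent_dim` recovers both parts.
import numpy as np

_obs_dim, _latent_dim = 6, 3
_flat_obs = np.concatenate([np.zeros(_obs_dim), np.ones(_latent_dim)])[None]  # (1, obs_dim + latent_dim)
_cond_input = _flat_obs[:, :-_latent_dim]  # (1, obs_dim): prior / conditioning observation
_z = _flat_obs[:, -_latent_dim:]           # (1, latent_dim): latent skill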
def _act(self, obs):
    assert len(obs.shape) == 2 and obs.shape[0] == 1  # assume single-observation batches with leading 1-dim
    if not self.action_plan:
        # generate action plan if the current one is empty
        split_obs = self._split_obs(obs)
        # batchnorm statistics are unreliable for single-element batches, so freeze them during decoding
        context = no_batchnorm_update(self._policy) if obs.shape[0] == 1 else contextlib.suppress()
        with context:
            actions = self._policy.decode(map2torch(split_obs.z, self._hp.device),
                                          map2torch(split_obs.cond_input, self._hp.device),
                                          self._policy.n_rollout_steps)
        self.action_plan = deque(split_along_axis(map2np(actions), axis=1))
    return AttrDict(action=self.action_plan.popleft())
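# A minimal sketch (the stub below is hypothetical, not the original agent) of
# how the action plan amortizes one decoder call over n_rollout_steps
# environment steps: `_act` only re-decodes once the deque runs empty.
from collections import deque
import numpy as np

class _PlanStub:
    def __init__(self, n_rollout_steps=5, action_dim=2):
        self.action_plan = deque()
        self.n_rollout_steps = n_rollout_steps
        self.action_dim = action_dim

    def _act(self, obs):
        if not self.action_plan:
            # stand-in for the decoder call: one action per rollout step
            plan = np.zeros((1, self.n_rollout_steps, self.action_dim))
            self.action_plan = deque(np.split(plan, self.n_rollout_steps, axis=1))
        return self.action_plan.popleft()

_stub = _PlanStub()
_actions = [_stub._act(np.zeros((1, 4))) for _ in range(7)]  # re-decodes only at steps 0 and 5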
def _act_rand(self, obs):
    policy_output = self.policy.sample_rand(map2torch(obs, self.policy.device))
    if 'dist' in policy_output:
        del policy_output['dist']
    return map2np(policy_output)
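# A minimal sketch (hypothetical; the real `sample_rand` lives on the policy and
# its action bounds are not shown here) of random warm-up exploration matching
# the output-dict structure `_act_rand` expects.
import torch

def _sketch_sample_rand(obs, action_dim=2):
    return {'action': torch.rand(obs.shape[0], action_dim) * 2 - 1}  # uniform in [-1, 1]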