def _act(self, obs):
    assert len(obs.shape) == 2 and obs.shape[0] == 1  # assume single-observation batches with leading 1-dim
    if not self.action_plan:
        # generate action plan if the current one is empty
        split_obs = self._split_obs(obs)
        with no_batchnorm_update(self._policy) if obs.shape[0] == 1 else contextlib.suppress():
            actions = self._policy.decode(map2torch(split_obs.z, self._hp.device),
                                          map2torch(split_obs.cond_input, self._hp.device),
                                          self._policy.n_rollout_steps)
        self.action_plan = deque(split_along_axis(map2np(actions), axis=1))
    return AttrDict(action=self.action_plan.popleft())
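# Hedged sketch (an assumption, not the repo's actual helper): no_batchnorm_update is
# used above so that decoding from a single-observation batch does not corrupt the
# BatchNorm running statistics of the policy network. A minimal version could
# temporarily switch all BatchNorm layers to eval mode and restore them afterwards.
import contextlib
import torch.nn as nn

@contextlib.contextmanager
def no_batchnorm_update_sketch(model):
    bn_layers = [m for m in model.modules() if isinstance(m, nn.modules.batchnorm._BatchNorm)]
    prev_modes = [m.training for m in bn_layers]
    for m in bn_layers:
        m.eval()    # freeze running-statistics updates during the forward pass
    try:
        yield model
    finally:
        for m, mode in zip(bn_layers, prev_modes):
            m.train(mode)    # restore each layer's original train/eval state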
def run(self, inputs, use_learned_prior=True):
    """Policy interface for model. Runs decoder if action plan is empty, otherwise returns next action from action plan.
    :arg inputs: dict with 'states', 'actions', 'images' keys from environment
    :arg use_learned_prior: if True, uses the learned prior; otherwise samples the latent from a fixed unit Gaussian prior
    """
    if not self._action_plan:
        inputs = map2torch(inputs, device=self.device)

        # sample latent variable from prior
        z = self.compute_learned_prior(self._learned_prior_input(inputs), first_only=True).sample() \
            if use_learned_prior else Gaussian(torch.zeros((1, self._hp.nz_vae*2), device=self.device)).sample()

        # decode into action plan
        z = z.repeat(self._hp.batch_size, 1)  # this is a HACK: the flat LSTM decoder can only take batch_size inputs
        input_obs = self._learned_prior_input(inputs).repeat(self._hp.batch_size, 1)
        actions = self.decode(z, cond_inputs=input_obs, steps=self._hp.n_rollout_steps)[0]
        self._action_plan = deque(split_along_axis(map2np(actions), axis=0))

    return AttrDict(action=self._action_plan.popleft()[None])
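# Hedged usage sketch (env/model names and the 'states' key are assumptions for
# illustration, assuming a flat numpy state observation): a rollout loop queries
# run() once per environment step; the model only re-runs the prior and decoder
# when its internal action plan has been exhausted.
def rollout_sketch(env, model, max_steps=500):
    obs = env.reset()
    for _ in range(max_steps):
        output = model.run(AttrDict(states=obs[None]), use_learned_prior=True)
        obs, reward, done, info = env.step(output.action[0])
        if done:
            break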
def _act(self, obs):
    # TODO implement non-sampling validation mode
    obs = map2torch(self._obs_normalizer(obs), self._hp.device)
    if len(obs.shape) == 1:  # we need batched inputs for policy
        policy_output = self._remove_batch(self.policy(obs[None]))
        if 'dist' in policy_output:
            del policy_output['dist']
        return map2np(policy_output)
    return map2np(self.policy(obs))
def _split_obs(self, obs):
    unflattened_obs = map2np(self._policy.unflatten_obs(
        map2torch(obs[:, :-self._policy.latent_dim], device=self.device)))
    return AttrDict(
        cond_input=unflattened_obs.prior_obs,
        z=obs[:, -self._policy.latent_dim:],
    )
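# Hedged sketch of the split_along_axis helper used by the action-plan methods above
# (an assumption about its behavior): it turns a decoded rollout of shape
# (..., n_steps, action_dim) into a list of per-step actions, which is then wrapped
# in a deque. A minimal numpy version:
import numpy as np

def split_along_axis_sketch(array, axis):
    return [np.take(array, i, axis=axis) for i in range(array.shape[axis])]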
def update(self, experience_batch):
    if 'delay' in self._hp.omega_schedule_params and self._update_steps < self._hp.omega_schedule_params.delay:
        # if the schedule has a warmup phase in which *only* the prior is sampled, train the policy to minimize its divergence from the prior
        self.replay_buffer.append(experience_batch)
        experience_batch = self.replay_buffer.sample(n_samples=self._hp.batch_size)
        experience_batch = map2torch(experience_batch, self._hp.device)
        policy_output = self._run_policy(experience_batch.observation)
        policy_loss = policy_output.prior_divergence.mean()
        self._perform_update(policy_loss, self.policy_opt, self.policy)
        self._update_steps += 1
        info = AttrDict(prior_divergence=policy_output.prior_divergence.mean())
    else:
        info = super().update(experience_batch)

    info.omega = self._omega(self._update_steps)
    return info
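# Hedged sketch (an assumption; the actual schedule class in the codebase may differ):
# _omega above is treated as a scalar schedule over update steps, e.g. a weight that
# stays at its initial value during the warmup delay and then ramps linearly to a
# final value.
def linear_omega_schedule_sketch(step, delay=0, ramp_steps=10000, initial=0.0, final=1.0):
    if step < delay:
        return initial
    frac = min(1.0, (step - delay) / float(ramp_steps))
    return initial + frac * (final - initial)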
def _act_rand(self, obs):
    policy_output = self.policy.sample_rand(map2torch(obs, self.policy.device))
    if 'dist' in policy_output:
        del policy_output['dist']
    return map2np(policy_output)
def update(self, experience_batch):
    """Updates actor and critics."""
    # push experience batch into replay buffer
    self.add_experience(experience_batch)

    for _ in range(self._hp.update_iterations):
        # sample batch and normalize
        experience_batch = self._sample_experience()
        experience_batch = self._normalize_batch(experience_batch)
        experience_batch = map2torch(experience_batch, self._hp.device)
        experience_batch = self._preprocess_experience(experience_batch)

        policy_output = self._run_policy(experience_batch.observation)

        # update alpha
        alpha_loss = self._update_alpha(experience_batch, policy_output)

        # compute policy loss
        policy_loss = self._compute_policy_loss(experience_batch, policy_output)

        # compute target Q value
        with torch.no_grad():
            policy_output_next = self._run_policy(experience_batch.observation_next)
            value_next = self._compute_next_value(experience_batch, policy_output_next)
            q_target = experience_batch.reward * self._hp.reward_scale + \
                (1 - experience_batch.done) * self._hp.discount_factor * value_next
            if self._hp.clip_q_target:
                q_target = self._clip_q_target(q_target)
            q_target = q_target.detach()
            check_shape(q_target, [self._hp.batch_size])

        # compute critic loss
        critic_losses, qs = self._compute_critic_loss(experience_batch, q_target)

        # update critic networks
        [self._perform_update(critic_loss, critic_opt, critic)
         for critic_loss, critic_opt, critic in zip(critic_losses, self.critic_opts, self.critics)]

        # update target networks
        [self._soft_update_target_network(critic_target, critic)
         for critic_target, critic in zip(self.critic_targets, self.critics)]

        # update policy network on policy loss
        self._perform_update(policy_loss, self.policy_opt, self.policy)

        # logging
        info = AttrDict(    # losses
            policy_loss=policy_loss,
            alpha_loss=alpha_loss,
            critic_loss_1=critic_losses[0],
            critic_loss_2=critic_losses[1],
        )
        if self._update_steps % 100 == 0:
            info.update(AttrDict(    # gradient norms
                policy_grad_norm=avg_grad_norm(self.policy),
                critic_1_grad_norm=avg_grad_norm(self.critics[0]),
                critic_2_grad_norm=avg_grad_norm(self.critics[1]),
            ))
        info.update(AttrDict(    # misc
            alpha=self.alpha,
            pi_log_prob=policy_output.log_prob.mean(),
            policy_entropy=policy_output.dist.entropy().mean(),
            q_target=q_target.mean(),
            q_1=qs[0].mean(),
            q_2=qs[1].mean(),
        ))
        info.update(self._aux_info(experience_batch, policy_output))
        info = map_dict(ten2ar, info)

        self._update_steps += 1

    return info
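# Hedged sketch (standard Polyak averaging; the repo's actual helper may differ): the
# _soft_update_target_network call above moves each target-critic parameter a small
# step of size tau toward the corresponding online-critic parameter.
import torch

def soft_update_target_network_sketch(target, source, tau=0.005):
    with torch.no_grad():
        for t_param, s_param in zip(target.parameters(), source.parameters()):
            t_param.data.mul_(1.0 - tau).add_(tau * s_param.data)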