def act(self, obs):
    """Output dict contains is_hl_step in case high-level action was performed during this action."""
    obs_input = obs[None] if len(obs.shape) == 1 else obs   # need batch input for agents
    output = AttrDict()
    if self._perform_hl_step_now:
        # perform step with high-level policy
        self._last_hl_output = self.hl_agent.act(obs_input)
        output.is_hl_step = True
        if len(obs_input.shape) == 2 and len(self._last_hl_output.action.shape) == 1:
            self._last_hl_output.action = self._last_hl_output.action[None]   # add batch dim if necessary
            self._last_hl_output.log_prob = self._last_hl_output.log_prob[None]
    else:
        output.is_hl_step = False
    output.update(prefix_dict(self._last_hl_output, 'hl_'))

    # perform step with low-level policy
    assert self._last_hl_output is not None
    output.update(self.ll_agent.act(self.make_ll_obs(obs_input, self._last_hl_output.action)))

    return self._remove_batch(output) if len(obs.shape) == 1 else output

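# Hedged usage sketch (not part of the original code): how `act` might be driven
# from a rollout loop. The env step API and the helper name `rollout_sketch` are
# assumptions for illustration. `self._perform_hl_step_now` is expected to flip
# true only periodically inside the agent, so the HL policy re-plans every few
# steps while the LL policy outputs an action at every step.
def rollout_sketch(agent, env, max_steps=500):
    obs = env.reset()
    for _ in range(max_steps):
        agent_output = agent.act(obs)   # AttrDict with action, is_hl_step, hl_* keys
        obs, reward, done, info = env.step(agent_output.action)
        if done:
            break
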
def get_episode_info(self):
    episode_info = AttrDict(
        episode_reward=self._episode_reward,
        episode_length=self._episode_step,
    )
    if hasattr(self._env, "get_episode_info"):
        episode_info.update(self._env.get_episode_info())
    return episode_info

def get_episode_info(self):
    episode_info = AttrDict()
    flag_names = ['_1_grasped', '_2_lift', '_4_stack', '_5_stack_final']
    flag_values = [self._grasped_flag, self._lifted_flag, self._stacked_flag, self._stacked_final_flag]
    # count over all tasks how many blocks achieved each subtask flag this episode
    for name, flags in zip(flag_names, flag_values):
        episode_info["block{}".format(name)] = \
            sum(int(flags[task_idx]) for task_idx in range(len(self._task)))
    return episode_info

def get_episode_info(self):
    episode_info = AttrDict()
    flag_names = ['_1_reach', '_2_lift', '_3_deliver', '_4_stack']
    flag_values = [self._reached_flag, self._lifted_flag, self._delivered_flag, self._stacked_flag]
    # count over all tasks how many blocks achieved each subtask flag this episode
    for name, flags in zip(flag_names, flag_values):
        episode_info["block{}".format(name)] = \
            sum(int(flags[task_idx]) for task_idx in range(len(self._task)))
    return episode_info

def update(self, experience_batches):
    """Updates high-level and low-level agents depending on which parameters are set."""
    assert isinstance(experience_batches, AttrDict)   # update requires batches for both HL and LL
    update_outputs = AttrDict()
    if self._hp.update_hl:
        hl_update_outputs = self.hl_agent.update(experience_batches.hl_batch)
        update_outputs.update(prefix_dict(hl_update_outputs, "hl_"))
    if self._hp.update_ll:
        ll_update_outputs = self.ll_agent.update(experience_batches.ll_batch)
        update_outputs.update(ll_update_outputs)
    return update_outputs

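# Hedged sketch (an assumption, not from the original excerpt): the expected shape
# of `experience_batches` -- an AttrDict carrying separate HL and LL batches. The
# field names `hl_batch` / `ll_batch` match the attribute accesses in `update`
# above; the buffer objects and their `sample` signature are hypothetical.
experience_batches = AttrDict(
    hl_batch=hl_replay_buffer.sample(n_samples=batch_size),   # hypothetical buffers
    ll_batch=ll_replay_buffer.sample(n_samples=batch_size),
)
update_outputs = agent.update(experience_batches)
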
def update(self, experience_batch):
    """Updates actor and critics."""
    # push experience batch into replay buffer
    self.add_experience(experience_batch)

    for _ in range(self._hp.update_iterations):
        # sample batch and normalize
        experience_batch = self._sample_experience()
        experience_batch = self._normalize_batch(experience_batch)
        experience_batch = map2torch(experience_batch, self._hp.device)
        experience_batch = self._preprocess_experience(experience_batch)

        policy_output = self._run_policy(experience_batch.observation)

        # update alpha
        alpha_loss = self._update_alpha(experience_batch, policy_output)

        # compute policy loss
        policy_loss = self._compute_policy_loss(experience_batch, policy_output)

        # compute target Q value
        with torch.no_grad():
            policy_output_next = self._run_policy(experience_batch.observation_next)
            value_next = self._compute_next_value(experience_batch, policy_output_next)
            q_target = experience_batch.reward * self._hp.reward_scale + \
                       (1 - experience_batch.done) * self._hp.discount_factor * value_next
            if self._hp.clip_q_target:
                q_target = self._clip_q_target(q_target)
            q_target = q_target.detach()
            check_shape(q_target, [self._hp.batch_size])

        # compute critic loss
        critic_losses, qs = self._compute_critic_loss(experience_batch, q_target)

        # update critic networks
        [self._perform_update(critic_loss, critic_opt, critic)
         for critic_loss, critic_opt, critic in zip(critic_losses, self.critic_opts, self.critics)]

        # update target networks
        [self._soft_update_target_network(critic_target, critic)
         for critic_target, critic in zip(self.critic_targets, self.critics)]

        # update policy network on policy loss
        self._perform_update(policy_loss, self.policy_opt, self.policy)

        # logging
        info = AttrDict(    # losses
            policy_loss=policy_loss,
            alpha_loss=alpha_loss,
            critic_loss_1=critic_losses[0],
            critic_loss_2=critic_losses[1],
        )
        if self._update_steps % 100 == 0:
            info.update(AttrDict(       # gradient norms
                policy_grad_norm=avg_grad_norm(self.policy),
                critic_1_grad_norm=avg_grad_norm(self.critics[0]),
                critic_2_grad_norm=avg_grad_norm(self.critics[1]),
            ))
        info.update(AttrDict(       # misc
            alpha=self.alpha,
            pi_log_prob=policy_output.log_prob.mean(),
            policy_entropy=policy_output.dist.entropy().mean(),
            q_target=q_target.mean(),
            q_1=qs[0].mean(),
            q_2=qs[1].mean(),
        ))
        info.update(self._aux_info(experience_batch, policy_output))
        info = map_dict(ten2ar, info)

        self._update_steps += 1

    return info

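# Minimal sketch (an assumption, not shown in the excerpt above) of the soft
# target-network update called in `update`, using the standard Polyak rule
#     theta_target <- tau * theta + (1 - tau) * theta_target.
# The hyperparameter name `target_network_update_factor` is hypothetical.
def _soft_update_target_network(self, target, source):
    tau = self._hp.target_network_update_factor
    for target_param, param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(tau * param.data + (1. - tau) * target_param.data)
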
    nz_vae=10,
    nz_mid=128,
    n_processing_layers=5,
    kl_div_weight=5e-4,
    cond_decode=True,
)

# LL Policy
ll_policy_params = AttrDict(
    policy_model=ClSPiRLMdl,
    policy_model_params=ll_model_params,
    policy_model_checkpoint=os.path.join(os.environ["EXP_DIR"],
                                         "skill_prior_learning/office/hierarchical_cl"),
    initial_log_sigma=-50,
)
ll_policy_params.update(ll_model_params)

# LL Critic
ll_critic_params = AttrDict(
    action_dim=data_spec.n_actions,
    input_dim=data_spec.state_dim,
    output_dim=1,
)

# LL Agent
ll_agent_config = copy.deepcopy(base_agent_params)
ll_agent_config.update(AttrDict(
    policy=ClModelPolicy,
    policy_params=ll_policy_params,
    critic=MLPCritic,   # LL critic is not used since we are not finetuning LL