Example #1
    def act(self, obs):
        """Output dict contains is_hl_step in case high-level action was performed during this action."""
        obs_input = obs[None] if len(obs.shape) == 1 else obs  # need batch input for agents
        output = AttrDict()
        if self._perform_hl_step_now:
            # perform step with high-level policy
            self._last_hl_output = self.hl_agent.act(obs_input)
            output.is_hl_step = True
            if len(obs_input.shape) == 2 and len(self._last_hl_output.action.shape) == 1:
                # add batch dim if necessary
                self._last_hl_output.action = self._last_hl_output.action[None]
                self._last_hl_output.log_prob = self._last_hl_output.log_prob[None]
        else:
            output.is_hl_step = False
        output.update(prefix_dict(self._last_hl_output, 'hl_'))

        # perform step with low-level policy
        assert self._last_hl_output is not None
        output.update(
            self.ll_agent.act(
                self.make_ll_obs(obs_input, self._last_hl_output.action)))

        return self._remove_batch(output) if len(obs.shape) == 1 else output
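
The examples here rely on two small utilities, AttrDict and prefix_dict, imported from a shared utility module. The definitions below are a minimal sketch of how such helpers typically behave, not the project's exact implementation:

class AttrDict(dict):
    """Dictionary whose entries are also accessible as attributes (minimal sketch)."""

    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError as e:
            raise AttributeError(key) from e

    def __setattr__(self, key, value):
        self[key] = value


def prefix_dict(d, prefix):
    """Return a copy of d with `prefix` prepended to every key (minimal sketch)."""
    return type(d)({prefix + key: value for key, value in d.items()})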
Example #2
    def get_episode_info(self):
        """Collects episode statistics and merges in any info the wrapped environment provides."""
        episode_info = AttrDict(
            episode_reward=self._episode_reward,
            episode_length=self._episode_step,
        )
        if hasattr(self._env, "get_episode_info"):
            episode_info.update(self._env.get_episode_info())
        return episode_info
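
The fields read here imply that the surrounding wrapper tracks reward and step count while the episode runs. A minimal sketch of that bookkeeping, assuming a Gym-style step/reset interface (the actual wrapper methods are not shown in the source):

    def step(self, action):
        """Sketch: forward to the wrapped env and accumulate episode statistics."""
        obs, reward, done, info = self._env.step(action)
        self._episode_reward += reward
        self._episode_step += 1
        return obs, reward, done, info

    def reset(self):
        """Sketch: reset the statistics together with the wrapped env."""
        self._episode_reward = 0.0
        self._episode_step = 0
        return self._env.reset()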
Example #3
    def get_episode_info(self):
        """Counts for how many blocks in the task each stacking sub-goal flag was achieved."""
        episode_info = AttrDict()

        flag_names = ['_1_grasped', '_2_lift', '_4_stack', '_5_stack_final']
        flag_values = [self._grasped_flag, self._lifted_flag, self._stacked_flag, self._stacked_final_flag]
        for i in range(len(flag_names)):
            episode_info["block" + flag_names[i]] = \
                sum([int(flag_values[i][task_idx]) for task_idx in range(len(self._task))])

        return episode_info
Example #4
    def get_episode_info(self):
        """Counts for how many blocks in the task each sub-goal flag (reach, lift, deliver, stack) was achieved."""
        episode_info = AttrDict()

        flag_names = ['_1_reach', '_2_lift', '_3_deliver', '_4_stack']
        flag_values = [self._reached_flag, self._lifted_flag,
                       self._delivered_flag, self._stacked_flag]
        for i in range(len(flag_names)):
            episode_info["block" + flag_names[i]] = \
                sum([int(flag_values[i][task_idx]) for task_idx in range(len(self._task))])

        return episode_info
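
Examples #3 and #4 follow the same aggregation pattern; iterating with zip over the names and flag lists removes the index bookkeeping. A behavior-equivalent sketch, assuming the same attributes as Example #4:

    def get_episode_info(self):
        """Sketch: same per-subtask success counts, written with zip instead of index arithmetic."""
        episode_info = AttrDict()
        flag_names = ['_1_reach', '_2_lift', '_3_deliver', '_4_stack']
        flag_values = [self._reached_flag, self._lifted_flag, self._delivered_flag, self._stacked_flag]
        for name, flags in zip(flag_names, flag_values):
            episode_info["block" + name] = sum(int(flags[task_idx]) for task_idx in range(len(self._task)))
        return episode_info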
Example #5
    def update(self, experience_batches):
        """Updates high-level and low-level agents depending on which parameters are set."""
        assert isinstance(experience_batches, AttrDict)  # update requires batches for both HL and LL
        update_outputs = AttrDict()
        if self._hp.update_hl:
            hl_update_outputs = self.hl_agent.update(experience_batches.hl_batch)
            update_outputs.update(prefix_dict(hl_update_outputs, "hl_"))
        if self._hp.update_ll:
            ll_update_outputs = self.ll_agent.update(experience_batches.ll_batch)
            update_outputs.update(ll_update_outputs)
        return update_outputs
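
The joint update expects an AttrDict carrying one batch per level. A hedged sketch of the call pattern; the replay-buffer names and batch size below are hypothetical, not part of the source:

# hypothetical driver code for the joint HL/LL update
experience_batches = AttrDict(
    hl_batch=hl_replay_buffer.sample(256),  # high-level transitions (assumed buffer API)
    ll_batch=ll_replay_buffer.sample(256),  # low-level transitions (assumed buffer API)
)
update_info = agent.update(experience_batches)  # HL losses come back prefixed with 'hl_'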
Example #6
    def update(self, experience_batch):
        """Updates actor and critics."""
        # push experience batch into replay buffer
        self.add_experience(experience_batch)

        for _ in range(self._hp.update_iterations):
            # sample batch and normalize
            experience_batch = self._sample_experience()
            experience_batch = self._normalize_batch(experience_batch)
            experience_batch = map2torch(experience_batch, self._hp.device)
            experience_batch = self._preprocess_experience(experience_batch)

            policy_output = self._run_policy(experience_batch.observation)

            # update alpha
            alpha_loss = self._update_alpha(experience_batch, policy_output)

            # compute policy loss
            policy_loss = self._compute_policy_loss(experience_batch,
                                                    policy_output)

            # compute target Q value
            with torch.no_grad():
                policy_output_next = self._run_policy(
                    experience_batch.observation_next)
                value_next = self._compute_next_value(experience_batch,
                                                      policy_output_next)
                q_target = (experience_batch.reward * self._hp.reward_scale +
                            (1 - experience_batch.done) * self._hp.discount_factor * value_next)
                if self._hp.clip_q_target:
                    q_target = self._clip_q_target(q_target)
                q_target = q_target.detach()
                check_shape(q_target, [self._hp.batch_size])

            # compute critic loss
            critic_losses, qs = self._compute_critic_loss(
                experience_batch, q_target)

            # update critic networks
            for critic_loss, critic_opt, critic in zip(
                    critic_losses, self.critic_opts, self.critics):
                self._perform_update(critic_loss, critic_opt, critic)

            # update target networks
            for critic_target, critic in zip(self.critic_targets, self.critics):
                self._soft_update_target_network(critic_target, critic)

            # update policy network on policy loss
            self._perform_update(policy_loss, self.policy_opt, self.policy)

            # logging
            info = AttrDict(  # losses
                policy_loss=policy_loss,
                alpha_loss=alpha_loss,
                critic_loss_1=critic_losses[0],
                critic_loss_2=critic_losses[1],
            )
            if self._update_steps % 100 == 0:
                info.update(
                    AttrDict(  # gradient norms
                        policy_grad_norm=avg_grad_norm(self.policy),
                        critic_1_grad_norm=avg_grad_norm(self.critics[0]),
                        critic_2_grad_norm=avg_grad_norm(self.critics[1]),
                    ))
            info.update(
                AttrDict(  # misc
                    alpha=self.alpha,
                    pi_log_prob=policy_output.log_prob.mean(),
                    policy_entropy=policy_output.dist.entropy().mean(),
                    q_target=q_target.mean(),
                    q_1=qs[0].mean(),
                    q_2=qs[1].mean(),
                ))
            info.update(self._aux_info(experience_batch, policy_output))
            info = map_dict(ten2ar, info)

            self._update_steps += 1

        return info
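
Example #6 calls two helpers whose bodies are not shown, _perform_update and _soft_update_target_network. The sketches below follow the standard SAC recipe (one optimizer step, then Polyak averaging of the target critics); the signatures match the call sites above, but the bodies and the tau value are assumptions rather than the project's definitions:

    def _perform_update(self, loss, opt, network):
        """Sketch: one gradient step on `loss`; `network` is passed for optional gradient clipping."""
        opt.zero_grad()
        loss.backward()
        opt.step()

    def _soft_update_target_network(self, target, source, tau=0.005):
        """Sketch: Polyak-average `source` parameters into `target` with coefficient tau."""
        for target_param, param in zip(target.parameters(), source.parameters()):
            target_param.data.copy_((1.0 - tau) * target_param.data + tau * param.data)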
Example #7
    nz_vae=10,
    nz_mid=128,
    n_processing_layers=5,
    kl_div_weight=5e-4,
    cond_decode=True,
)

# LL Policy
ll_policy_params = AttrDict(
    policy_model=ClSPiRLMdl,
    policy_model_params=ll_model_params,
    policy_model_checkpoint=os.path.join(
        os.environ["EXP_DIR"], "skill_prior_learning/office/hierarchical_cl"),
    initial_log_sigma=-50,
)
ll_policy_params.update(ll_model_params)

# LL Critic
ll_critic_params = AttrDict(
    action_dim=data_spec.n_actions,
    input_dim=data_spec.state_dim,
    output_dim=1,
)

# LL Agent
ll_agent_config = copy.deepcopy(base_agent_params)
ll_agent_config.update(
    AttrDict(
        policy=ClModelPolicy,
        policy_params=ll_policy_params,
        critic=MLPCritic,  # LL critic is not used since we are not finetuning LL