コード例 #1
0
ファイル: agent.py プロジェクト: xiaofei-w/spirl
    def act(self, obs):
        """Run one hierarchical step: optionally query the high-level policy, then the low-level one.

        Output dict contains is_hl_step in case high-level action was performed during this action.

        The high-level (HL) agent is queried only when ``self._perform_hl_step_now``
        is truthy. Otherwise the HL output stored by an earlier call is reused. The
        low-level (LL) agent is then queried on
        ``self.make_ll_obs(obs_input, hl_action)``.

        Args:
            obs: observation array. A 1-D ``obs`` is treated as a single unbatched
                observation and is indexed with ``obs[None]`` before being passed to
                the agents. Any other rank is passed through unchanged.

        Returns:
            AttrDict with the LL agent's outputs, the last HL outputs with keys
            prefixed by ``'hl_'``, and ``is_hl_step`` telling whether the HL policy
            was queried during this call. When ``obs`` was 1-D, the result is passed
            through ``self._remove_batch`` before being returned.
        """
        unbatched = len(obs.shape) == 1
        obs_input = obs[None] if unbatched else obs  # need batch input for agents
        output = AttrDict()
        if self._perform_hl_step_now:
            # perform step with high-level policy
            self._last_hl_output = self.hl_agent.act(obs_input)
            output.is_hl_step = True
            if len(obs_input.shape) == 2 and len(self._last_hl_output.action.shape) == 1:
                # HL agent returned an unbatched action for a batched input: add batch dim
                self._last_hl_output.action = self._last_hl_output.action[None]
                self._last_hl_output.log_prob = self._last_hl_output.log_prob[None]
        else:
            output.is_hl_step = False

        # Check the invariant before the HL output is consumed below, so a missing
        # HL step surfaces as this assertion rather than as an error inside prefix_dict.
        assert self._last_hl_output is not None, \
            "act() called before any high-level step was performed"
        output.update(prefix_dict(self._last_hl_output, 'hl_'))

        # perform step with low-level policy
        output.update(
            self.ll_agent.act(
                self.make_ll_obs(obs_input, self._last_hl_output.action)))

        return self._remove_batch(output) if unbatched else output
コード例 #2
0
ファイル: agent.py プロジェクト: clvrai/spirl
 def update(self, experience_batches):
     """Run the update step of the HL and/or LL agent, as enabled by the hyperparameters.

     Args:
         experience_batches: AttrDict holding ``hl_batch``, consumed when
             ``self._hp.update_hl`` is set, and ``ll_batch``, consumed when
             ``self._hp.update_ll`` is set.

     Returns:
         AttrDict combining the HL update outputs (keys prefixed with ``"hl_"``)
         with the LL update outputs (keys left unprefixed). The LL entries are
         merged in second.
     """
     # Both sub-batches arrive together in a single AttrDict container.
     assert isinstance(experience_batches, AttrDict)  # update requires batches for both HL and LL
     combined = AttrDict()

     if self._hp.update_hl:
         hl_stats = self.hl_agent.update(experience_batches.hl_batch)
         combined.update(prefix_dict(hl_stats, "hl_"))

     if self._hp.update_ll:
         combined.update(self.ll_agent.update(experience_batches.ll_batch))

     return combined
コード例 #3
0
ファイル: wandb.py プロジェクト: clvrai/spirl
 def log_scalar_dict(self, d, prefix='', step=None):
     """Log all entries of a dict of scalars to wandb, optionally prefixing every key.

     Args:
         d: mapping of metric name to scalar value.
         prefix: when non-empty, every key of ``d`` is prefixed with
             ``prefix + '_'`` (via ``prefix_dict``) before logging.
         step: step value forwarded to ``wandb.log``. ``None`` means no explicit
             step is given.
     """
     if prefix:
         d = prefix_dict(d, prefix + '_')
     # wandb.log's ``step`` parameter defaults to None, so forwarding it directly is
     # equivalent to omitting it and avoids a conditional expression used only
     # for its side effect.
     wandb.log(d, step=step)