def act(self, obs):
    """Output dict contains is_hl_step in case high-level action was performed during this action."""
    obs_input = obs[None] if len(obs.shape) == 1 else obs  # need batched input for agents
    output = AttrDict()
    if self._perform_hl_step_now:
        # perform step with high-level policy
        self._last_hl_output = self.hl_agent.act(obs_input)
        output.is_hl_step = True
        if len(obs_input.shape) == 2 and len(self._last_hl_output.action.shape) == 1:
            self._last_hl_output.action = self._last_hl_output.action[None]  # add batch dim if necessary
            self._last_hl_output.log_prob = self._last_hl_output.log_prob[None]
    else:
        output.is_hl_step = False
    output.update(prefix_dict(self._last_hl_output, 'hl_'))

    # perform step with low-level policy
    assert self._last_hl_output is not None
    output.update(self.ll_agent.act(self.make_ll_obs(obs_input, self._last_hl_output.action)))

    return self._remove_batch(output) if len(obs.shape) == 1 else output
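
# Minimal usage sketch (not from the original source): stepping an environment with the
# hierarchical agent's act(). The `env` object with a gym-style reset()/step() API, the
# `hierarchical_agent` instance, and `max_rollout_steps` are assumed names for illustration only.
#
#   obs = env.reset()                                    # unbatched observation, shape (obs_dim,)
#   for _ in range(max_rollout_steps):
#       agent_output = hierarchical_agent.act(obs)       # contains 'action', 'hl_action', 'is_hl_step', ...
#       obs, reward, done, info = env.step(agent_output.action)
#       if done:
#           break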
def update(self, experience_batches):
    """Updates high-level and low-level agents depending on which parameters are set."""
    assert isinstance(experience_batches, AttrDict)  # update requires batches for both HL and LL
    update_outputs = AttrDict()
    if self._hp.update_hl:
        hl_update_outputs = self.hl_agent.update(experience_batches.hl_batch)
        update_outputs.update(prefix_dict(hl_update_outputs, "hl_"))
    if self._hp.update_ll:
        ll_update_outputs = self.ll_agent.update(experience_batches.ll_batch)
        update_outputs.update(ll_update_outputs)
    return update_outputs
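
# Minimal usage sketch (assumption, not from the original source): update() expects an AttrDict
# with separate replay batches for the high-level and low-level agents. The `hl_replay_buffer`,
# `ll_replay_buffer`, their sample() method, and `batch_size` are hypothetical names here.
#
#   experience_batches = AttrDict(
#       hl_batch=hl_replay_buffer.sample(batch_size),    # consumed only if self._hp.update_hl
#       ll_batch=ll_replay_buffer.sample(batch_size),    # consumed only if self._hp.update_ll
#   )
#   update_outputs = agent.update(experience_batches)    # HL entries come back prefixed with 'hl_'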
def log_scalar_dict(self, d, prefix='', step=None):
    """Logs all entries from a dict of scalars. Optionally can prefix all keys in dict before logging."""
    if prefix:
        d = prefix_dict(d, prefix + '_')
    if step is None:
        wandb.log(d)
    else:
        wandb.log(d, step=step)
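
# Minimal usage sketch (assumption): logging the scalar outputs of an update step under a key
# prefix, so a key 'loss' is reported to wandb as 'train_loss'. The `logger`, `update_outputs`,
# and `global_step` names are illustrative, not from the original source.
#
#   logger.log_scalar_dict(update_outputs, prefix='train', step=global_step)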