def add_experiences(
    self,
    decision_steps: DecisionSteps,
    terminal_steps: TerminalSteps,
    worker_id: int,
    previous_action: ActionInfo,
) -> None:
    """
    Record the newest environment step of every agent belonging to one worker.

    Entropy values from the policy outputs are sent to the stats reporter, the
    outputs are remembered per agent, and every terminal and decision step is
    handed to ``_process_step`` after its group status/obs has been registered.
    :param decision_steps: current DecisionSteps.
    :param terminal_steps: current TerminalSteps.
    :param worker_id: id of the worker the steps come from; it is combined with
        each local agent id to form a global agent id.
    :param previous_action: The outputs of the Policy's get_action method.
    """
    policy_outputs = previous_action.outputs
    if policy_outputs:
        for entropy_value in policy_outputs["entropy"]:
            self._stats_reporter.add_stat("Policy/Entropy", entropy_value)

    # Local agent ids are qualified with the worker id so that they are
    # unique across workers.
    acting_agent_ids = [
        get_global_agent_id(worker_id, agent_id)
        for agent_id in previous_action.agent_ids
    ]

    for acting_id in acting_agent_ids:
        if acting_id not in self._last_step_result:
            # No previous step result: the agent just reset, store nothing.
            continue
        self._last_take_action_outputs[acting_id] = policy_outputs

    # Terminal steps are handled in two passes: every step is first passed to
    # _add_group_status_and_obs, and only then is each one processed.
    for finished_step in terminal_steps.values():
        self._add_group_status_and_obs(finished_step, worker_id)

    for finished_step in terminal_steps.values():
        local_id = finished_step.agent_id
        global_id = get_global_agent_id(worker_id, local_id)
        step_index = terminal_steps.agent_id_to_index[local_id]
        self._process_step(finished_step, worker_id, step_index)
        # The agent is done, so drop the group status/obs last seen for it.
        self._clear_group_status_and_obs(global_id)

    # Decision steps follow the same two-pass scheme, without the clean-up.
    for active_step in decision_steps.values():
        self._add_group_status_and_obs(active_step, worker_id)

    for active_step in decision_steps.values():
        local_id = active_step.agent_id
        step_index = decision_steps.agent_id_to_index[local_id]
        self._process_step(active_step, worker_id, step_index)

    # Remember the previous action only for agents that still have a last step
    # result; an id without one belongs to an agent that just reset.
    # NOTE(review): the same "action" entry is passed for every id together
    # with a one-element id list - confirm save_previous_action expects that.
    for acting_id in acting_agent_ids:
        if acting_id in self._last_step_result and "action" in policy_outputs:
            self.policy.save_previous_action(
                [acting_id], policy_outputs["action"]
            )
def add_experiences(
    self,
    decision_steps: DecisionSteps,
    terminal_steps: TerminalSteps,
    worker_id: int,
    previous_action: ActionInfo,
) -> None:
    """
    Record the newest environment step of every agent belonging to one worker.

    Entropy values from the policy outputs are sent to the stats reporter, the
    outputs are remembered per agent, and every terminal and decision step is
    handed to ``_process_step`` under the agent's global id.
    :param decision_steps: current DecisionSteps.
    :param terminal_steps: current TerminalSteps.
    :param worker_id: id of the worker the steps come from; it is combined with
        each local agent id to form a global agent id.
    :param previous_action: The outputs of the Policy's get_action method.
    """
    policy_outputs = previous_action.outputs
    if policy_outputs:
        for entropy_value in policy_outputs["entropy"]:
            self.stats_reporter.add_stat("Policy/Entropy", entropy_value)

    # Local agent ids are qualified with the worker id so that they are
    # unique across workers.
    acting_agent_ids = [
        get_global_agent_id(worker_id, agent_id)
        for agent_id in previous_action.agent_ids
    ]

    for acting_id in acting_agent_ids:
        if acting_id not in self.last_step_result:
            # No previous step result: the agent just reset, store nothing.
            continue
        self.last_take_action_outputs[acting_id] = policy_outputs

    # Process the terminal steps first ...
    for finished_step in terminal_steps.values():
        local_id = finished_step.agent_id
        global_id = get_global_agent_id(worker_id, local_id)
        step_index = terminal_steps.agent_id_to_index[local_id]
        self._process_step(finished_step, global_id, step_index)

    # ... and then the decision steps.
    for active_step in decision_steps.values():
        local_id = active_step.agent_id
        global_id = get_global_agent_id(worker_id, local_id)
        step_index = decision_steps.agent_id_to_index[local_id]
        self._process_step(active_step, global_id, step_index)

    # Remember the previous action only for agents that still have a last step
    # result; an id without one belongs to an agent that just reset.
    # NOTE(review): the same "action" entry is passed for every id together
    # with a one-element id list - confirm save_previous_action expects that.
    for acting_id in acting_agent_ids:
        if acting_id in self.last_step_result and "action" in policy_outputs:
            self.policy.save_previous_action(
                [acting_id], policy_outputs["action"]
            )