Exemplo n.º 1
0
    def add_experiences(
        self,
        decision_steps: DecisionSteps,
        terminal_steps: TerminalSteps,
        worker_id: int,
        previous_action: ActionInfo,
    ) -> None:
        """
        Adds the experiences of one step to each agent's experience history.

        Terminal steps are handled before decision steps. Within each kind,
        the group status/obs of every step is recorded first and only then is
        each step processed.
        :param decision_steps: current DecisionSteps.
        :param terminal_steps: current TerminalSteps.
        :param worker_id: worker id, passed to get_global_agent_id together
            with each local agent id to build the global agent ids.
        :param previous_action: The outputs of the Policy's get_action method.
        """
        policy_outputs = previous_action.outputs
        if policy_outputs:
            for entropy_value in policy_outputs["entropy"]:
                self._stats_reporter.add_stat("Policy/Entropy", entropy_value)

        # Build agent ids that are unique across workers.
        acting_agent_ids = []
        for local_agent_id in previous_action.agent_ids:
            acting_agent_ids.append(
                get_global_agent_id(worker_id, local_agent_id)
            )

        for agent_key in acting_agent_ids:
            if agent_key not in self._last_step_result:
                # Agent just reset: don't store its outputs.
                continue
            self._last_take_action_outputs[agent_key] = policy_outputs

        # Terminal steps, pass 1: gather the group status/obs of every step
        # before any AgentExperience/Trajectory is created.
        for finished_step in terminal_steps.values():
            self._add_group_status_and_obs(finished_step, worker_id)
        # Terminal steps, pass 2: process each step.
        for finished_step in terminal_steps.values():
            agent_local_id = finished_step.agent_id
            agent_key = get_global_agent_id(worker_id, agent_local_id)
            step_index = terminal_steps.agent_id_to_index[agent_local_id]
            self._process_step(finished_step, worker_id, step_index)
            # Clear the last seen group obs when agents die.
            self._clear_group_status_and_obs(agent_key)

        # Decision steps follow the same two-pass scheme: gather all group
        # status/obs first, then process each step.
        for live_step in decision_steps.values():
            self._add_group_status_and_obs(live_step, worker_id)
        for live_step in decision_steps.values():
            step_index = decision_steps.agent_id_to_index[live_step.agent_id]
            self._process_step(live_step, worker_id, step_index)

        for agent_key in acting_agent_ids:
            # An ID without a last step result belongs to an agent that just
            # reset, so its action is not stored.
            if (
                agent_key in self._last_step_result
                and "action" in policy_outputs
            ):
                self.policy.save_previous_action(
                    [agent_key], policy_outputs["action"]
                )
Exemplo n.º 2
0
    def add_experiences(
        self,
        decision_steps: DecisionSteps,
        terminal_steps: TerminalSteps,
        worker_id: int,
        previous_action: ActionInfo,
    ) -> None:
        """
        Adds the experiences of one step to each agent's experience history.

        Terminal steps are processed before decision steps.
        :param decision_steps: current DecisionSteps.
        :param terminal_steps: current TerminalSteps.
        :param worker_id: worker id, passed to get_global_agent_id together
            with each local agent id to build the global agent ids.
        :param previous_action: The outputs of the Policy's get_action method.
        """
        policy_outputs = previous_action.outputs
        if policy_outputs:
            for entropy_value in policy_outputs["entropy"]:
                self.stats_reporter.add_stat("Policy/Entropy", entropy_value)

        # Build agent ids that are unique across workers.
        acting_agent_ids = []
        for local_agent_id in previous_action.agent_ids:
            acting_agent_ids.append(get_global_agent_id(worker_id, local_agent_id))

        for agent_key in acting_agent_ids:
            if agent_key not in self.last_step_result:
                # Agent just reset: don't store its outputs.
                continue
            self.last_take_action_outputs[agent_key] = policy_outputs

        # Process every terminal step.
        for finished_step in terminal_steps.values():
            agent_local_id = finished_step.agent_id
            agent_key = get_global_agent_id(worker_id, agent_local_id)
            step_index = terminal_steps.agent_id_to_index[agent_local_id]
            self._process_step(finished_step, agent_key, step_index)

        # Process every decision step.
        for live_step in decision_steps.values():
            agent_local_id = live_step.agent_id
            agent_key = get_global_agent_id(worker_id, agent_local_id)
            step_index = decision_steps.agent_id_to_index[agent_local_id]
            self._process_step(live_step, agent_key, step_index)

        for agent_key in acting_agent_ids:
            # An ID without a last step result belongs to an agent that just
            # reset, so its action is not stored.
            if agent_key in self.last_step_result and "action" in policy_outputs:
                self.policy.save_previous_action(
                    [agent_key], policy_outputs["action"]
                )