Example #1: postprocess_trajectory() override of a MARWIL-style policy (discounted returns as advantages)
    def postprocess_trajectory(
        self,
        sample_batch: SampleBatch,
        other_agent_batches: Optional[Dict[Any, SampleBatch]] = None,
        episode: Optional["Episode"] = None,
    ):
        sample_batch = super().postprocess_trajectory(sample_batch,
                                                      other_agent_batches,
                                                      episode)

        # Trajectory is actually complete -> last r=0.0.
        if sample_batch[SampleBatch.DONES][-1]:
            last_r = 0.0
        # Trajectory has been truncated -> last r=VF estimate of last obs.
        else:
            # Create a single-timestep input dict (covering only the last
            # step in the trajectory) according to the Model's view
            # requirements.
            index = "last" if SampleBatch.NEXT_OBS in sample_batch else -1
            input_dict = sample_batch.get_single_step_input_dict(
                self.model.view_requirements, index=index)
            last_r = self._value(**input_dict)

        # Adds the "advantages" (which in the case of MARWIL are simply the
        # discounted cumulative rewards) to the SampleBatch.
        return compute_advantages(
            sample_batch,
            last_r,
            self.config["gamma"],
            # We only want the discounted cumulative rewards, so we need
            # neither GAE nor the critic (use_critic=True would subtract the
            # vf estimates from the returns).
            use_gae=False,
            use_critic=False,
        )
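
With use_gae=False and use_critic=False, compute_advantages() reduces to plain discounted cumulative rewards, bootstrapped with last_r at the truncation point. A minimal standalone NumPy sketch of that arithmetic (for illustration only; the helper name discounted_returns is made up here, this is not RLlib's implementation):

import numpy as np

def discounted_returns(rewards, last_r, gamma):
    # R_t = r_t + gamma * R_{t+1}, seeded with R_T = last_r (0.0 for a
    # finished episode, the value estimate of the last obs otherwise).
    returns = np.zeros(len(rewards), dtype=np.float32)
    running = last_r
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns

# A truncated 3-step trajectory, bootstrapped with a value estimate of 0.5:
print(discounted_returns([1.0, 0.0, 1.0], last_r=0.5, gamma=0.99))
# -> approx. [2.465, 1.480, 1.495]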
Example #2: compute_gae_for_sample_batch() (adds GAE advantages to a trajectory)
def compute_gae_for_sample_batch(
    policy: Policy,
    sample_batch: SampleBatch,
    other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
    episode: Optional[Episode] = None,
) -> SampleBatch:
    """Adds GAE (generalized advantage estimations) to a trajectory.

    The trajectory contains only data from one episode and from one agent.
    - If `config.batch_mode=truncate_episodes` (default), sample_batch may
    contain an episode that was truncated at the end because the sampler
    reached `config.rollout_fragment_length`.
    - If `config.batch_mode=complete_episodes`, sample_batch will contain
    exactly one episode (no matter how long).
    New columns can be added to sample_batch and existing ones may be altered.

    Args:
        policy: The Policy used to generate the trajectory (`sample_batch`)
        sample_batch: The SampleBatch to postprocess.
        other_agent_batches: Optional dict of AgentIDs mapping to other
            agents' trajectory data (from the same episode).
            NOTE: The other agents use the same policy.
        episode: Optional multi-agent episode object in which the agents
            operated.

    Returns:
        The postprocessed, modified SampleBatch (or a new one).
    """

    # Trajectory is actually complete -> last r=0.0.
    if sample_batch[SampleBatch.DONES][-1]:
        last_r = 0.0
    # Trajectory has been truncated -> last r=VF estimate of last obs.
    else:
        # Create a single-timestep input dict (covering only the last step in
        # the trajectory) according to the Model's view requirements.
        input_dict = sample_batch.get_single_step_input_dict(
            policy.model.view_requirements, index="last"
        )
        last_r = policy._value(**input_dict)

    # Adds the policy logits, VF preds, and advantages to the batch,
    # using GAE ("generalized advantage estimation") or not.
    batch = compute_advantages(
        sample_batch,
        last_r,
        policy.config["gamma"],
        policy.config["lambda"],
        use_gae=policy.config["use_gae"],
        use_critic=policy.config.get("use_critic", True),
    )

    return batch
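
When use_gae=True, compute_advantages() follows the GAE recursion from Schulman et al. (2016): delta_t = r_t + gamma * V(s_{t+1}) - V(s_t) and A_t = delta_t + gamma * lambda * A_{t+1}, with last_r again serving as the bootstrap value of the state after the trajectory. A minimal standalone NumPy sketch of that recursion (illustrative only, not RLlib's implementation; the function name gae is made up):

import numpy as np

def gae(rewards, vf_preds, last_r, gamma, lambda_):
    # Bootstrap the value of the state following the final timestep.
    vf_next = np.append(vf_preds[1:], last_r)
    # TD residuals: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t).
    deltas = rewards + gamma * vf_next - vf_preds
    advantages = np.zeros_like(deltas)
    running = 0.0
    for t in reversed(range(len(deltas))):
        running = deltas[t] + gamma * lambda_ * running
        advantages[t] = running
    # Value-function targets = advantages + baseline predictions.
    return advantages, advantages + vf_preds

adv, vf_targets = gae(
    rewards=np.array([1.0, 0.0, 1.0]),
    vf_preds=np.array([0.6, 0.4, 0.9]),
    last_r=0.5,
    gamma=0.99,
    lambda_=0.95,
)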
Example #3: postprocess_advantages() (MARWIL's functional postprocessor; same logic as Example #1)
def postprocess_advantages(
    policy: Policy,
    sample_batch: SampleBatch,
    other_agent_batches: Optional[Dict[PolicyID, SampleBatch]] = None,
    episode=None,
) -> SampleBatch:
    """Postprocesses a trajectory and returns the processed trajectory.

    The trajectory contains only data from one episode and from one agent.
    - If `config.batch_mode=truncate_episodes` (default), sample_batch may
    contain an episode that was truncated at the end because the sampler
    reached `config.rollout_fragment_length`.
    - If `config.batch_mode=complete_episodes`, sample_batch will contain
    exactly one episode (no matter how long).
    New columns can be added to sample_batch and existing ones may be altered.

    Args:
        policy (Policy): The Policy used to generate the trajectory
            (`sample_batch`)
        sample_batch (SampleBatch): The SampleBatch to postprocess.
        other_agent_batches (Optional[Dict[PolicyID, SampleBatch]]): Optional
            dict of AgentIDs mapping to other agents' trajectory data (from the
            same episode). NOTE: The other agents use the same policy.
        episode (Optional[Episode]): Optional multi-agent episode
            object in which the agents operated.

    Returns:
        SampleBatch: The postprocessed, modified SampleBatch (or a new one).
    """

    # Trajectory is actually complete -> last r=0.0.
    if sample_batch[SampleBatch.DONES][-1]:
        last_r = 0.0
    # Trajectory has been truncated -> last r=VF estimate of last obs.
    else:
        # Create a single-timestep input dict (covering only the last step in
        # the trajectory) according to the Model's view requirements.
        index = "last" if SampleBatch.NEXT_OBS in sample_batch else -1
        input_dict = sample_batch.get_single_step_input_dict(
            policy.model.view_requirements, index=index)
        last_r = policy._value(**input_dict)

    # Adds the "advantages" (which in the case of MARWIL are simply the
    # discounted cumulative rewards) to the SampleBatch.
    return compute_advantages(
        sample_batch,
        last_r,
        policy.config["gamma"],
        # We only want the discounted cumulative rewards, so we need neither
        # GAE nor the critic (use_critic=True would subtract the vf estimates
        # from the returns).
        use_gae=False,
        use_critic=False,
    )
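
Compared to Example #2 (which always passes index="last"), this version falls back to index=-1 when the batch carries no NEXT_OBS column, so the value bootstrap is computed from the last stored observation instead of the one following it. A small sketch of just that check, using made-up toy batches (the data below is purely illustrative):

import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch

# Toy batches: one that carries NEXT_OBS and one that does not (e.g. because
# the Model's view requirements never asked for it).
with_next_obs = SampleBatch({
    SampleBatch.OBS: np.array([[0.1], [0.2]]),
    SampleBatch.NEXT_OBS: np.array([[0.2], [0.3]]),
})
without_next_obs = SampleBatch({
    SampleBatch.OBS: np.array([[0.1], [0.2]]),
})

for b in (with_next_obs, without_next_obs):
    index = "last" if SampleBatch.NEXT_OBS in b else -1
    print(index)  # -> "last", then -1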
Example #4: unit test for SampleBatch.get_single_step_input_dict() with batch_repeat_value=1
    def test_get_single_step_input_dict_batch_repeat_value_1(self):
        """Test whether a SampleBatch produces the correct 1-step input dict."""
        space = Box(-1.0, 1.0, ())

        # With batch_repeat_value=1, state_in_0 is built at every timestep.
        view_reqs = {
            "state_in_0":
            ViewRequirement(
                data_col="state_out_0",
                shift="-5:-1",
                space=space,
                batch_repeat_value=1,
            ),
            "state_out_0":
            ViewRequirement(space=space, used_for_compute_actions=False),
        }

        # Trajectory of 1 ts (ts=0); we want the input for the next step (ts=1).
        batch = SampleBatch({
            "state_in_0": np.array([
                [0, 0, 0, 0, 0],  # ts=0
            ]),
            "state_out_0": np.array([1]),
        })
        input_dict = batch.get_single_step_input_dict(
            view_requirements=view_reqs, index="last")
        check(
            input_dict,
            {
                "state_in_0": [[0, 0, 0, 0, 1]],  # ts=1
                "seq_lens": [1],
            },
        )

        # Trajectory of 6 ts (ts=0-5); we want the input for the next step (ts=6).
        batch = SampleBatch({
            "state_in_0":
            np.array([
                [0, 0, 0, 0, 0],  # ts=0
                [0, 0, 0, 0, 1],  # ts=1
                [0, 0, 0, 1, 2],  # ts=2
                [0, 0, 1, 2, 3],  # ts=3
                [0, 1, 2, 3, 4],  # ts=4
                [1, 2, 3, 4, 5],  # ts=5
            ]),
            "state_out_0":
            np.array([1, 2, 3, 4, 5, 6]),
        })
        input_dict = batch.get_single_step_input_dict(
            view_requirements=view_reqs, index="last")
        check(
            input_dict,
            {
                "state_in_0": [[2, 3, 4, 5, 6]],  # ts=6
                "seq_lens": [1],
            },
        )

        # Trajectory of 12 ts (ts=0-11); we want the input for the next step (ts=12).
        batch = SampleBatch({
            "state_in_0":
            np.array([
                [0, 0, 0, 0, 0],  # ts=0
                [0, 0, 0, 0, 1],  # ts=1
                [0, 0, 0, 1, 2],  # ts=2
                [0, 0, 1, 2, 3],  # ts=3
                [0, 1, 2, 3, 4],  # ts=4
                [1, 2, 3, 4, 5],  # ts=5
                [2, 3, 4, 5, 6],  # ts=6
                [3, 4, 5, 6, 7],  # ts=7
                [4, 5, 6, 7, 8],  # ts=8
                [5, 6, 7, 8, 9],  # ts=9
                [6, 7, 8, 9, 10],  # ts=10
                [7, 8, 9, 10, 11],  # ts=11
            ]),
            "state_out_0":
            np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]),
        })
        input_dict = batch.get_single_step_input_dict(
            view_requirements=view_reqs, index="last")
        check(
            input_dict,
            {
                "state_in_0": [[8, 9, 10, 11, 12]],  # ts=12
                "seq_lens": [1],
            },
        )
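
Reading the expected values above: with shift="-5:-1" and batch_repeat_value=1, the state_in_0 row that comes back for index="last" is simply the last five state_out_0 values, left-padded with zeros when the trajectory is shorter than five steps. A standalone NumPy check of that pattern (only the arithmetic behind the expected outputs, not RLlib code; last_n_state_outs is a made-up helper name):

import numpy as np

def last_n_state_outs(state_out, n=5):
    # Take the last n state_out values and left-pad with zeros if the
    # trajectory is shorter than n steps.
    tail = np.asarray(state_out)[-n:]
    return np.concatenate([np.zeros(n - len(tail), dtype=tail.dtype), tail])

print(last_n_state_outs([1]))                 # -> [0 0 0 0 1]
print(last_n_state_outs([1, 2, 3, 4, 5, 6]))  # -> [2 3 4 5 6]
print(last_n_state_outs(list(range(1, 13))))  # -> [ 8  9 10 11 12]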