def postprocess_trajectory(
    self,
    sample_batch: SampleBatch,
    other_agent_batches: Optional[Dict[Any, SampleBatch]] = None,
    episode: Optional["Episode"] = None,
):
    sample_batch = super().postprocess_trajectory(
        sample_batch, other_agent_batches, episode
    )

    # Trajectory is actually complete -> last r=0.0.
    if sample_batch[SampleBatch.DONES][-1]:
        last_r = 0.0
    # Trajectory has been truncated -> last r=VF estimate of last obs.
    else:
        # Input dict is provided to us automatically via the Model's
        # requirements. It's a single-timestep (last one in trajectory)
        # input_dict.
        # Create an input dict according to the Model's requirements.
        index = "last" if SampleBatch.NEXT_OBS in sample_batch else -1
        input_dict = sample_batch.get_single_step_input_dict(
            self.model.view_requirements, index=index
        )
        last_r = self._value(**input_dict)

    # Adds the "advantages" (which in the case of MARWIL are simply the
    # discounted cumulative rewards) to the SampleBatch.
    return compute_advantages(
        sample_batch,
        last_r,
        self.config["gamma"],
        # We just want the discounted cumulative rewards, so we need neither
        # GAE nor a critic (use_critic=True would subtract the vf estimates
        # from the returns).
        use_gae=False,
        use_critic=False,
    )
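# For reference, a minimal sketch of what the use_gae=False, use_critic=False
# branch of `compute_advantages` boils down to: the "advantages" are just the
# discounted cumulative rewards, seeded with the bootstrap value `last_r`.
# The helper name `discounted_returns` is hypothetical, not RLlib API.
import numpy as np


def discounted_returns(rewards, last_r, gamma=0.99):
    """R_t = r_t + gamma * R_{t+1}, with R_T seeded by the bootstrap `last_r`."""
    returns = np.zeros(len(rewards), dtype=np.float64)
    running = last_r
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns


# E.g. rewards=[1, 1, 1], last_r=0.5, gamma=0.9:
# R_2 = 1 + 0.9 * 0.5   = 1.45
# R_1 = 1 + 0.9 * 1.45  = 2.305
# R_0 = 1 + 0.9 * 2.305 = 3.0745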
def compute_gae_for_sample_batch(
    policy: Policy,
    sample_batch: SampleBatch,
    other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
    episode: Optional[Episode] = None,
) -> SampleBatch:
    """Adds GAE (generalized advantage estimations) to a trajectory.

    The trajectory contains only data from one episode and from one agent.
    - If `config.batch_mode=truncate_episodes` (default), sample_batch may
      contain a truncated (at-the-end) episode, in case the
      `config.rollout_fragment_length` was reached by the sampler.
    - If `config.batch_mode=complete_episodes`, sample_batch will contain
      exactly one episode (no matter how long).
    New columns can be added to sample_batch and existing ones may be altered.

    Args:
        policy: The Policy used to generate the trajectory (`sample_batch`).
        sample_batch: The SampleBatch to postprocess.
        other_agent_batches: Optional dict of AgentIDs mapping to other
            agents' trajectory data (from the same episode).
            NOTE: The other agents use the same policy.
        episode: Optional multi-agent episode object in which the agents
            operated.

    Returns:
        The postprocessed, modified SampleBatch (or a new one).
    """
    # Trajectory is actually complete -> last r=0.0.
    if sample_batch[SampleBatch.DONES][-1]:
        last_r = 0.0
    # Trajectory has been truncated -> last r=VF estimate of last obs.
    else:
        # Input dict is provided to us automatically via the Model's
        # requirements. It's a single-timestep (last one in trajectory)
        # input_dict.
        # Create an input dict according to the Model's requirements.
        input_dict = sample_batch.get_single_step_input_dict(
            policy.model.view_requirements, index="last"
        )
        last_r = policy._value(**input_dict)

    # Adds the policy logits, VF preds, and advantages to the batch,
    # using GAE ("generalized advantage estimation") or not.
    batch = compute_advantages(
        sample_batch,
        last_r,
        policy.config["gamma"],
        policy.config["lambda"],
        use_gae=policy.config["use_gae"],
        use_critic=policy.config.get("use_critic", True),
    )

    return batch
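# For reference, a minimal NumPy sketch of the textbook GAE recursion that
# `compute_advantages` applies in the use_gae=True case. This illustrates the
# math, not RLlib's exact implementation; `gae_advantages` is a hypothetical
# name:
#   delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
#   A_t     = delta_t + gamma * lambda * A_{t+1}
import numpy as np


def gae_advantages(rewards, vf_preds, last_r, gamma=0.99, lambda_=0.95):
    # Append the bootstrap value so vf[t + 1] is defined for the last step.
    vf = np.append(vf_preds, last_r)
    deltas = rewards + gamma * vf[1:] - vf[:-1]
    advantages = np.zeros(len(rewards), dtype=np.float64)
    gae = 0.0
    # Accumulate the discounted deltas from the back of the trajectory.
    for t in reversed(range(len(rewards))):
        gae = deltas[t] + gamma * lambda_ * gae
        advantages[t] = gae
    # Critic regression targets: A_t + V(s_t).
    value_targets = advantages + vf_preds
    return advantages, value_targets


# Example: a 3-step truncated rollout, bootstrapped with last_r=0.5.
# adv, targets = gae_advantages(
#     np.array([1.0, 1.0, 1.0]), np.array([0.9, 0.8, 0.7]), last_r=0.5
# )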
def postprocess_advantages(
    policy: Policy,
    sample_batch: SampleBatch,
    other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
    episode=None,
) -> SampleBatch:
    """Postprocesses a trajectory and returns the processed trajectory.

    The trajectory contains only data from one episode and from one agent.
    - If `config.batch_mode=truncate_episodes` (default), sample_batch may
      contain a truncated (at-the-end) episode, in case the
      `config.rollout_fragment_length` was reached by the sampler.
    - If `config.batch_mode=complete_episodes`, sample_batch will contain
      exactly one episode (no matter how long).
    New columns can be added to sample_batch and existing ones may be altered.

    Args:
        policy (Policy): The Policy used to generate the trajectory
            (`sample_batch`).
        sample_batch (SampleBatch): The SampleBatch to postprocess.
        other_agent_batches (Optional[Dict[AgentID, SampleBatch]]): Optional
            dict of AgentIDs mapping to other agents' trajectory data (from
            the same episode). NOTE: The other agents use the same policy.
        episode (Optional[Episode]): Optional multi-agent episode object in
            which the agents operated.

    Returns:
        SampleBatch: The postprocessed, modified SampleBatch (or a new one).
    """
    # Trajectory is actually complete -> last r=0.0.
    if sample_batch[SampleBatch.DONES][-1]:
        last_r = 0.0
    # Trajectory has been truncated -> last r=VF estimate of last obs.
    else:
        # Input dict is provided to us automatically via the Model's
        # requirements. It's a single-timestep (last one in trajectory)
        # input_dict.
        # Create an input dict according to the Model's requirements.
        index = "last" if SampleBatch.NEXT_OBS in sample_batch else -1
        input_dict = sample_batch.get_single_step_input_dict(
            policy.model.view_requirements, index=index
        )
        last_r = policy._value(**input_dict)

    # Adds the "advantages" (which in the case of MARWIL are simply the
    # discounted cumulative rewards) to the SampleBatch.
    return compute_advantages(
        sample_batch,
        last_r,
        policy.config["gamma"],
        # We just want the discounted cumulative rewards, so we need neither
        # GAE nor a critic (use_critic=True would subtract the vf estimates
        # from the returns).
        use_gae=False,
        use_critic=False,
    )
def test_get_single_step_input_dict_batch_repeat_value_1(self):
    """Test whether a SampleBatch produces the correct 1-step input dict."""
    space = Box(-1.0, 1.0, ())

    # With batch-repeat-value==1: state_in_0 is built each timestep.
    view_reqs = {
        "state_in_0": ViewRequirement(
            data_col="state_out_0",
            shift="-5:-1",
            space=space,
            batch_repeat_value=1,
        ),
        "state_out_0": ViewRequirement(
            space=space, used_for_compute_actions=False
        ),
    }

    # Trajectory of 1 ts (0) (we would like to compute the 1st).
    batch = SampleBatch(
        {
            "state_in_0": np.array(
                [
                    [0, 0, 0, 0, 0],  # ts=0
                ]
            ),
            "state_out_0": np.array([1]),
        }
    )
    input_dict = batch.get_single_step_input_dict(
        view_requirements=view_reqs, index="last"
    )
    check(
        input_dict,
        {
            "state_in_0": [[0, 0, 0, 0, 1]],  # ts=1
            "seq_lens": [1],
        },
    )

    # Trajectory of 6 ts (0-5) (we would like to compute the 6th).
    batch = SampleBatch(
        {
            "state_in_0": np.array(
                [
                    [0, 0, 0, 0, 0],  # ts=0
                    [0, 0, 0, 0, 1],  # ts=1
                    [0, 0, 0, 1, 2],  # ts=2
                    [0, 0, 1, 2, 3],  # ts=3
                    [0, 1, 2, 3, 4],  # ts=4
                    [1, 2, 3, 4, 5],  # ts=5
                ]
            ),
            "state_out_0": np.array([1, 2, 3, 4, 5, 6]),
        }
    )
    input_dict = batch.get_single_step_input_dict(
        view_requirements=view_reqs, index="last"
    )
    check(
        input_dict,
        {
            "state_in_0": [[2, 3, 4, 5, 6]],  # ts=6
            "seq_lens": [1],
        },
    )

    # Trajectory of 12 ts (0-11) (we would like to compute the 12th).
    batch = SampleBatch(
        {
            "state_in_0": np.array(
                [
                    [0, 0, 0, 0, 0],  # ts=0
                    [0, 0, 0, 0, 1],  # ts=1
                    [0, 0, 0, 1, 2],  # ts=2
                    [0, 0, 1, 2, 3],  # ts=3
                    [0, 1, 2, 3, 4],  # ts=4
                    [1, 2, 3, 4, 5],  # ts=5
                    [2, 3, 4, 5, 6],  # ts=6
                    [3, 4, 5, 6, 7],  # ts=7
                    [4, 5, 6, 7, 8],  # ts=8
                    [5, 6, 7, 8, 9],  # ts=9
                    [6, 7, 8, 9, 10],  # ts=10
                    [7, 8, 9, 10, 11],  # ts=11
                ]
            ),
            "state_out_0": np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]),
        }
    )
    input_dict = batch.get_single_step_input_dict(
        view_requirements=view_reqs, index="last"
    )
    check(
        input_dict,
        {
            "state_in_0": [[8, 9, 10, 11, 12]],  # ts=12
            "seq_lens": [1],
        },
    )
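# The expected "state_in_0" rows in the checks above all follow one pattern:
# with shift="-5:-1", the single-step input for the next timestep is the last
# input window shifted left by one, with the newest state_out_0 appended.
# A minimal sketch of that expectation (hypothetical helper, not RLlib API):
import numpy as np


def expected_next_state_in(state_in, state_out):
    # Drop the oldest entry of the last window, append the newest state-out.
    return np.concatenate([state_in[-1][1:], state_out[-1:]])


# For the 6-timestep trajectory above:
# expected_next_state_in(np.array([[1, 2, 3, 4, 5]]), np.array([6]))
# -> array([2, 3, 4, 5, 6])  # matches the checked "state_in_0" for ts=6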