def build_and_reset(
        self, episode: Optional[MultiAgentEpisode] = None) -> MultiAgentBatch:
    """Postprocesses pending rows, then returns and clears all batches.

    Any rows not yet postprocessed are first run through the policies'
    postprocessors. Afterwards, the builder's internal state (per-policy
    builders and the env-step count) is reset.

    Args:
        episode (Optional[MultiAgentEpisode]): The episode that owns this
            builder, or None.

    Returns:
        MultiAgentBatch: The accumulated sample batches, one per policy
            that collected at least one row.
    """
    self.postprocess_batch_so_far(episode)
    # Only policies that actually accumulated rows contribute a batch.
    batches_per_policy = {
        pid: builder.build_and_reset()
        for pid, builder in self.policy_builders.items()
        if builder.count > 0
    }
    # Capture the env-step count before zeroing it out.
    env_steps, self.count = self.count, 0
    return MultiAgentBatch.wrap_as_needed(batches_per_policy, env_steps)
def build_multi_agent_batch(self, env_steps: int) -> \
        Union[MultiAgentBatch, SampleBatch]:
    """Builds a (possibly wrapped) batch from all non-empty collectors.

    Args:
        env_steps (int): The number of env steps to record on the
            returned batch.

    Returns:
        Union[MultiAgentBatch, SampleBatch]: The built batch; may be
            unwrapped to a single SampleBatch by `wrap_as_needed`.
    """
    # Gather one SampleBatch per policy that collected anything.
    batches = {}
    for pid, collector in self.policy_collectors.items():
        if collector.count > 0:
            batches[pid] = collector.build()
    ma_batch = MultiAgentBatch.wrap_as_needed(batches, env_steps=env_steps)
    # Reset the shared env-step counter for the next collection round.
    self.policy_collectors_env_steps = 0
    return ma_batch
def get_multi_agent_batch_and_reset(self):
    """Postprocesses pending trajectories, then builds and returns a batch.

    Each per-policy sample collector is drained via its
    `get_train_sample_batch_and_reset` using that policy's training view
    requirements. The across-all-agents env-step count is reset to 0.

    Returns:
        The accumulated batches wrapped via
        `MultiAgentBatch.wrap_as_needed`.
    """
    self.postprocess_trajectories_so_far()
    # Drain every policy's collector, keyed by policy id.
    per_policy = {
        pid: collector.get_train_sample_batch_and_reset(
            self.policy_map[pid].training_view_requirements)
        for pid, collector in self.policy_sample_collectors.items()
    }
    batch = MultiAgentBatch.wrap_as_needed(per_policy, self.count)
    # Reset our across-all-agents env step count.
    self.count = 0
    return batch
def _build_multi_agent_batch(self, episode: MultiAgentEpisode) -> \
        Union[MultiAgentBatch, SampleBatch]:
    """Builds a batch from the episode's batch builder and resets its count.

    Args:
        episode (MultiAgentEpisode): The episode whose batch builder
            (PolicyCollectorGroup) holds the collected per-policy data.

    Returns:
        Union[MultiAgentBatch, SampleBatch]: The built batch; may be
            unwrapped to a single SampleBatch by `wrap_as_needed`.
    """
    group = episode.batch_builder
    # Only policies with at least one collected row produce a batch.
    per_policy = {
        pid: collector.build()
        for pid, collector in group.policy_collectors.items()
        if collector.count > 0
    }
    batch = MultiAgentBatch.wrap_as_needed(
        per_policy, env_steps=group.count)
    # Mark the group's step count as consumed.
    group.count = 0
    return batch