def wait_on_replay_actors(timeout: float) -> None:
    """Collect replay-actor samples that become ready within ``timeout`` seconds.

    Waits on the requests tracked in ``self.remote_replay_requests_in_flight``.
    For every ready sample batch, appends one ``(replay_actor, sample_batch)``
    tuple to ``self.replay_sample_batches``. Per the documented contract, a
    ``timeout`` of ``None`` blocks on the actors indefinitely.

    NOTE(review): ``self`` is not a parameter of this function, so it must be
    a free variable captured from an enclosing scope. Presumably this is a
    helper nested inside a method; confirm before moving it to class level.

    Args:
        timeout: Seconds to wait, forwarded as ``ray_wait_timeout_s``, or
            ``None`` to wait without a limit.
    """
    ready_by_actor: Dict[ActorHandle, T] = wait_asynchronous_requests(
        remote_requests_in_flight=self.remote_replay_requests_in_flight,
        ray_wait_timeout_s=timeout,
    )
    # Pair each ready batch with the replay actor that produced it, preserving
    # the per-actor iteration order of the returned mapping.
    ready_pairs = (
        (actor, batch)
        for actor, batches in ready_by_actor.items()
        for batch in batches
    )
    for pair in ready_pairs:
        self.replay_sample_batches.append(pair)
def process_experiences_tree_aggregation(
    self, actor_to_sample_batches_refs: Dict[ActorHandle, List[ObjectRef]]
) -> List[SampleBatchType]:
    """Send sample batches to aggregator actors and gather ready results.

    Each sample-batch ref in ``actor_to_sample_batches_refs`` is submitted to
    one aggregator picked with ``random.choice`` from
    ``self.aggregator_workers``, via the aggregator's ``process_episodes``
    remote method. Results ready within ``self.config["aggregator_wait_timeout"]``
    are collected in two places:

    * from each submission call;
    * from a final wait over every request still tracked in
      ``self.remote_aggregator_requests_in_flight``, possibly including
      requests left over from earlier calls.

    Args:
        actor_to_sample_batches_refs: Mapping from an actor handle to a list
            of sample-batch object refs.

    Returns:
        A list of the processed results that were ready, possibly empty.
        Never ``None``.

        NOTE(review): elements are annotated as ``SampleBatchType`` on the
        assumption that ``asynchronous_parallel_requests`` and
        ``wait_asynchronous_requests`` return resolved results rather than
        object refs. Confirm against those helpers.
    """
    # Flatten the per-actor ref lists into a single sequence of refs.
    batches = [
        sample_batch_ref
        for refs_batch in actor_to_sample_batches_refs.values()
        for sample_batch_ref in refs_batch
    ]
    # The same timeout is used for every wait below; read it once.
    wait_timeout_s = self.config["aggregator_wait_timeout"]

    ready_processed_batches = []
    for batch in batches:
        # Choose an aggregator at random for each batch.
        aggregator = random.choice(self.aggregator_workers)
        processed_sample_batches = asynchronous_parallel_requests(
            remote_requests_in_flight=self.remote_aggregator_requests_in_flight,
            actors=[aggregator],
            # ``batch`` is passed as an argument, not captured, so there is
            # no late-binding issue with this lambda.
            remote_fn=lambda actor, b: actor.process_episodes(b),
            remote_kwargs=[{"b": batch}],
            ray_wait_timeout_s=wait_timeout_s,
            # Presumably lifts any per-aggregator in-flight cap so every batch
            # gets submitted; confirm against asynchronous_parallel_requests.
            max_remote_requests_in_flight_per_actor=float("inf"),
        )
        for ready_sub_batches in processed_sample_batches.values():
            ready_processed_batches.extend(ready_sub_batches)

    # Pick up results of tracked requests that did not finish during
    # submission. This runs even when ``batches`` is empty.
    waiting_processed_sample_batches = wait_asynchronous_requests(
        remote_requests_in_flight=self.remote_aggregator_requests_in_flight,
        ray_wait_timeout_s=wait_timeout_s,
    )
    for ready_sub_batches in waiting_processed_sample_batches.values():
        ready_processed_batches.extend(ready_sub_batches)

    return ready_processed_batches