示例#1
0
文件: apex.py 项目: alipay/ray
        def wait_on_replay_actors(timeout: float) -> None:
            """Gather finished replay samples into ``self.replay_sample_batches``.

            Calls ``wait_asynchronous_requests`` on the in-flight replay
            requests and records every returned sample batch together with
            the replay actor that produced it.

            Args:
                timeout: Forwarded unchanged as ``ray_wait_timeout_s``. Per
                    the original contract of this helper, ``None`` is expected
                    to block until the actors are done.
                    NOTE(review): the annotation says ``float`` although
                    ``None`` is documented as valid -- confirm and widen.
            """
            ready_by_actor: Dict[ActorHandle, T] = wait_asynchronous_requests(
                remote_requests_in_flight=(
                    self.remote_replay_requests_in_flight),
                ray_wait_timeout_s=timeout,
            )

            # Flatten {actor: [batch, ...]} into (actor, batch) pairs so the
            # consumer of `replay_sample_batches` knows each batch's origin.
            actor_batch_pairs = (
                (actor, batch)
                for actor, batches in ready_by_actor.items()
                for batch in batches
            )
            for pair in actor_batch_pairs:
                self.replay_sample_batches.append(pair)
示例#2
0
    def process_experiences_tree_aggregation(
        self, actor_to_sample_batches_refs: Dict[ActorHandle, List[ObjectRef]]
    ) -> Union[SampleBatchType, None]:
        """Route sampled batches through the aggregator workers.

        Every batch reference found in ``actor_to_sample_batches_refs`` is
        submitted to one aggregator picked at random from
        ``self.aggregator_workers``, which is asked to run
        ``process_episodes`` on it. Whatever results each submission call
        hands back are collected, followed by the results of one final
        ``wait_asynchronous_requests`` call on the aggregator requests that
        are still in flight.

        Args:
            actor_to_sample_batches_refs: Mapping from an actor to the batch
                references it produced. Only the values are used.

        Returns:
            A list with all processed sub-batches gathered during this call
            (possibly empty).
            NOTE(review): the return annotation advertises
            ``Union[SampleBatchType, None]`` but a list is always returned --
            confirm the intended type and update the annotation.
        """

        def _process_on_aggregator(actor, b):
            # The parameter name `b` must match the key in `remote_kwargs`.
            return actor.process_episodes(b)

        collected = []

        def _absorb(results_by_actor):
            # Merge the per-actor result lists into the flat output list.
            for sub_batches in results_by_actor.values():
                collected.extend(sub_batches)

        # Flatten the per-actor lists up front; the producing actor is
        # irrelevant for aggregation.
        pending_refs = []
        for refs in actor_to_sample_batches_refs.values():
            pending_refs.extend(refs)

        for ref in pending_refs:
            target = random.choice(self.aggregator_workers)
            _absorb(
                asynchronous_parallel_requests(
                    remote_requests_in_flight=(
                        self.remote_aggregator_requests_in_flight),
                    actors=[target],
                    remote_fn=_process_on_aggregator,
                    remote_kwargs=[{"b": ref}],
                    ray_wait_timeout_s=self.config["aggregator_wait_timeout"],
                    # No cap on queued requests per aggregator.
                    max_remote_requests_in_flight_per_actor=float("inf"),
                ))

        # One more pass over the requests that are still outstanding.
        _absorb(
            wait_asynchronous_requests(
                remote_requests_in_flight=(
                    self.remote_aggregator_requests_in_flight),
                ray_wait_timeout_s=self.config["aggregator_wait_timeout"],
            ))

        return collected