Example 1
File: apex.py Project: alipay/ray
    def setup(self, config: PartialTrainerConfigDict):
        super().setup(config)

        # Shortcut: If the execution_plan API is enabled, the learner thread
        # and replay buffers will be created there instead, so skip this setup.
        if self.config["_disable_execution_plan_api"] is False:
            return

        # Tag those workers (the top 1/3rd of indices) whose episodes should
        # be collected for metrics: the `PerWorkerEpsilonGreedy` exploration
        # strategy assigns them the lowest (greediest) epsilons.
        if self.workers.remote_workers():
            self._remote_workers_for_metrics = self.workers.remote_workers()[
                -len(self.workers.remote_workers()) // 3:]

        num_replay_buffer_shards = self.config["optimizer"][
            "num_replay_buffer_shards"]
        buffer_size = (self.config["replay_buffer_config"]["capacity"] //
                       num_replay_buffer_shards)
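        # NOTE: `buffer_size` is the per-shard capacity: the total configured
        # capacity split evenly across the replay buffer shards.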
        replay_actor_args = [
            num_replay_buffer_shards,
            self.config["learning_starts"],
            buffer_size,
            self.config["train_batch_size"],
            self.config["replay_buffer_config"]["prioritized_replay_alpha"],
            self.config["replay_buffer_config"]["prioritized_replay_beta"],
            self.config["replay_buffer_config"]["prioritized_replay_eps"],
            self.config["multiagent"]["replay_mode"],
            self.config["replay_buffer_config"].get("replay_sequence_length",
                                                    1),
        ]
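        # The positional args above line up with ReplayActor's constructor:
        # the first seven (num_shards, learning_starts, buffer_size,
        # replay_batch_size, prioritized_replay_alpha, prioritized_replay_beta,
        # prioritized_replay_eps) are visible as keywords in Example 2 below;
        # the last two presumably map to replay_mode and replay_sequence_length.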
        # Place all replay buffer shards on the same node as the learner
        # (driver process that runs this execution plan).
        if self.config["replay_buffer_shards_colocated_with_driver"]:
            self.replay_actors = create_colocated_actors(
                actor_specs=[  # (class, args, kwargs={}, count)
                    (ReplayActor, replay_actor_args, {},
                     num_replay_buffer_shards)
                ],
                node=platform.node(),  # localhost
            )[0]  # [0]: `actor_specs` holds only a single entry.
        # Place replay buffer shards on any node(s).
        else:
            self.replay_actors = [
                ReplayActor.remote(*replay_actor_args)
                for _ in range(num_replay_buffer_shards)
            ]
        # Create and start the learner thread on the local worker (driver).
        self.learner_thread = LearnerThread(self.workers.local_worker())
        self.learner_thread.start()

        # Bookkeeping: weight-broadcast counters, the initial weights placed
        # in the object store, and per-actor maps of in-flight remote requests.
        self.steps_since_update = defaultdict(int)
        weights = self.workers.local_worker().get_weights()
        self.curr_learner_weights = ray.put(weights)
        self.remote_sampling_requests_in_flight: DefaultDict[
            ActorHandle, Set[ray.ObjectRef]] = defaultdict(set)
        self.remote_replay_requests_in_flight: DefaultDict[
            ActorHandle, Set[ray.ObjectRef]] = defaultdict(set)
        self.curr_num_samples_collected = 0
        self.replay_sample_batches = []
        self._num_ts_trained_since_last_target_update = 0
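
For orientation, below is a minimal sketch of a config dict covering every key that setup() reads above. The key names come straight from the excerpt; the concrete values (and the assumption that this method belongs to RLlib's APEX trainer) are illustrative only:

config = {
    "_disable_execution_plan_api": True,  # proceed past the shortcut above
    "optimizer": {"num_replay_buffer_shards": 4},
    "replay_buffer_config": {
        "capacity": 2000000,  # split evenly across the 4 shards
        "prioritized_replay_alpha": 0.6,
        "prioritized_replay_beta": 0.4,
        "prioritized_replay_eps": 1e-6,
        "replay_sequence_length": 1,
    },
    "learning_starts": 50000,
    "train_batch_size": 512,
    "multiagent": {"replay_mode": "independent"},
    "replay_buffer_shards_colocated_with_driver": True,
}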
Example 2
def test_store_to_replay_actor(ray_start_regular_shared):
    actor = ReplayActor.remote(num_shards=1,
                               learning_starts=200,
                               buffer_size=1000,
                               replay_batch_size=100,
                               prioritized_replay_alpha=0.6,
                               prioritized_replay_beta=0.4,
                               prioritized_replay_eps=0.0001)
    assert ray.get(actor.replay.remote()) is None

    workers = make_workers(0)
    a = ParallelRollouts(workers, mode="bulk_sync")
    b = a.for_each(StoreToReplayBuffer(actors=[actor]))

    next(b)
    assert ray.get(actor.replay.remote()) is None  # learning hasn't started
    next(b)
    assert ray.get(actor.replay.remote()).count == 100

    replay_op = Replay(actors=[actor])
    assert next(replay_op).count == 100
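
The two assertions around next(b) hinge on learning_starts: each call stores one rollout batch of 100 timesteps, and replay() returns None until 200 timesteps have accumulated. Below is a minimal sketch of that gating in plain Python (no Ray; TinyReplayBuffer is a hypothetical stand-in, not RLlib's ReplayActor):

import random

class TinyReplayBuffer:
    """Hypothetical stand-in illustrating the `learning_starts` gate."""

    def __init__(self, learning_starts, replay_batch_size):
        self.learning_starts = learning_starts
        self.replay_batch_size = replay_batch_size
        self.storage = []

    def add_batch(self, timesteps):
        self.storage.extend(timesteps)

    def replay(self):
        # Like ReplayActor.replay() above: return nothing until
        # `learning_starts` timesteps have been stored.
        if len(self.storage) < self.learning_starts:
            return None
        return random.choices(self.storage, k=self.replay_batch_size)

buf = TinyReplayBuffer(learning_starts=200, replay_batch_size=100)
buf.add_batch(range(100))       # first rollout: 100 < 200, still warming up
assert buf.replay() is None
buf.add_batch(range(100))       # second rollout: 200 >= 200, batches flow
assert len(buf.replay()) == 100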