def setup(self, config: PartialTrainerConfigDict):
    super().setup(config)

    # Shortcut: If the execution plan API is used, the learner thread and
    # the replay actors are created inside the execution plan instead.
    if self.config["_disable_execution_plan_api"] is False:
        return

    # Tag those workers (top 1/3rd indices) that we should collect episodes
    # from for metrics, due to the `PerWorkerEpsilonGreedy` exploration
    # strategy: higher-index workers use the lowest epsilons and thus act
    # closest to greedy.
    if self.workers.remote_workers():
        self._remote_workers_for_metrics = self.workers.remote_workers()[
            -len(self.workers.remote_workers()) // 3:]

    num_replay_buffer_shards = self.config["optimizer"][
        "num_replay_buffer_shards"]
    buffer_size = (self.config["replay_buffer_config"]["capacity"] //
                   num_replay_buffer_shards)
    replay_actor_args = [
        num_replay_buffer_shards,
        self.config["learning_starts"],
        buffer_size,
        self.config["train_batch_size"],
        self.config["replay_buffer_config"]["prioritized_replay_alpha"],
        self.config["replay_buffer_config"]["prioritized_replay_beta"],
        self.config["replay_buffer_config"]["prioritized_replay_eps"],
        self.config["multiagent"]["replay_mode"],
        self.config["replay_buffer_config"].get("replay_sequence_length", 1),
    ]

    # Place all replay buffer shards on the same node as the learner
    # (the driver process running this Trainer).
    if self.config["replay_buffer_shards_colocated_with_driver"]:
        self.replay_actors = create_colocated_actors(
            actor_specs=[
                # (class, args, kwargs={}, count)
                (ReplayActor, replay_actor_args, {},
                 num_replay_buffer_shards)
            ],
            node=platform.node(),  # localhost
        )[0]  # [0] b/c there is only one item in `actor_specs`.
    # Place replay buffer shards on any node(s).
    else:
        self.replay_actors = [
            ReplayActor.remote(*replay_actor_args)
            for _ in range(num_replay_buffer_shards)
        ]

    # Start the learner thread on the local (learner) worker.
    self.learner_thread = LearnerThread(self.workers.local_worker())
    self.learner_thread.start()

    self.steps_since_update = defaultdict(int)

    # Put the current learner weights into the object store once, so all
    # rollout workers can fetch them cheaply.
    weights = self.workers.local_worker().get_weights()
    self.curr_learner_weights = ray.put(weights)

    # Track in-flight async requests per remote actor: sampling requests to
    # rollout workers and replay requests to replay actors.
    self.remote_sampling_requests_in_flight: DefaultDict[
        ActorHandle, Set[ray.ObjectRef]] = defaultdict(set)
    self.remote_replay_requests_in_flight: DefaultDict[
        ActorHandle, Set[ray.ObjectRef]] = defaultdict(set)

    self.curr_num_samples_collected = 0
    self.replay_sample_batches = []
    self._num_ts_trained_since_last_target_update = 0
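
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the trainer): the "requests in flight"
# bookkeeping that `setup()` initializes above. `Sampler` and its `sample()`
# method are hypothetical stand-ins for RLlib's rollout workers; only the Ray
# calls (`@ray.remote`, `.remote()`, `ray.wait()`, `ray.get()`) are real API.
# ---------------------------------------------------------------------------
from collections import defaultdict
from typing import DefaultDict, Set

import ray
from ray.actor import ActorHandle


@ray.remote
class Sampler:
    def sample(self) -> int:
        return 100  # Pretend we collected 100 env steps.


def _inflight_demo():
    ray.init(ignore_reinit_error=True)
    samplers = [Sampler.remote() for _ in range(3)]
    in_flight: DefaultDict[ActorHandle, Set[ray.ObjectRef]] = defaultdict(set)
    ref_to_actor = {}

    # Launch one async request per actor and remember which actor owns
    # which pending ObjectRef.
    for s in samplers:
        ref = s.sample.remote()
        in_flight[s].add(ref)
        ref_to_actor[ref] = s

    # Drain results as they become ready, clearing the per-actor bookkeeping
    # so the same actor can be asked again without unbounded pending requests.
    pending = list(ref_to_actor)
    while pending:
        done, pending = ray.wait(pending, num_returns=1)
        for ref in done:
            in_flight[ref_to_actor[ref]].discard(ref)
            print("collected", ray.get(ref), "env steps")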
def test_store_to_replay_actor(ray_start_regular_shared):
    actor = ReplayActor.remote(
        num_shards=1,
        learning_starts=200,
        buffer_size=1000,
        replay_batch_size=100,
        prioritized_replay_alpha=0.6,
        prioritized_replay_beta=0.4,
        prioritized_replay_eps=0.0001,
    )
    # Nothing stored yet -> no replay possible.
    assert ray.get(actor.replay.remote()) is None

    workers = make_workers(0)
    a = ParallelRollouts(workers, mode="bulk_sync")
    b = a.for_each(StoreToReplayBuffer(actors=[actor]))

    # First 100-step batch: still below `learning_starts=200`.
    next(b)
    assert ray.get(actor.replay.remote()) is None  # Learning hasn't started.
    # Second batch brings the buffer to 200 stored timesteps.
    next(b)
    assert ray.get(actor.replay.remote()).count == 100

    replay_op = Replay(actors=[actor])
    assert next(replay_op).count == 100
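
# ---------------------------------------------------------------------------
# Illustrative sketch (hypothetical `TinyBuffer`, not RLlib API): why the
# test's first `replay()` call returns None. A buffer gated by
# `learning_starts` only serves batches once at least that many timesteps
# have been stored, which is exactly what the two `next(b)` calls above
# exercise (100 stored steps, then 200).
# ---------------------------------------------------------------------------
class TinyBuffer:
    def __init__(self, learning_starts: int, replay_batch_size: int):
        self.learning_starts = learning_starts
        self.replay_batch_size = replay_batch_size
        self.storage = []

    def add(self, timesteps):
        self.storage.extend(timesteps)

    def replay(self):
        # Not enough data yet -> mirrors the first assert above.
        if len(self.storage) < self.learning_starts:
            return None
        return self.storage[:self.replay_batch_size]


buf = TinyBuffer(learning_starts=200, replay_batch_size=100)
buf.add(range(100))
assert buf.replay() is None       # 100 < 200: learning hasn't started.
buf.add(range(100))
assert len(buf.replay()) == 100   # 200 >= 200: 100-step batches come back.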