Esempio n. 1
0
File: apex.py Progetto: alipay/ray
    def setup(self, config: PartialTrainerConfigDict):
        """Set up Ape-X state for the non-execution-plan training path.

        The parent ``setup`` always runs first. If
        ``self.config["_disable_execution_plan_api"]`` is exactly ``False``,
        the learner thread and replay shards are created inside
        ``execution_plan`` instead, so this method returns right away.
        Otherwise it:

        * records which remote workers to collect episode metrics from,
        * creates the replay-buffer shard actors,
        * starts the learner thread on the local worker, and
        * initializes the bookkeeping attributes (in-flight request maps,
          sample counters, pending replay batches).

        Args:
            config: Partial trainer config, passed to ``super().setup``.
                Everything after that call reads ``self.config``
                (presumably populated by the parent -- confirm).
        """
        super().setup(config)

        cfg = self.config

        # Execution-plan mode: thread and buffers get built there, not here.
        if cfg["_disable_execution_plan_api"] is False:
            return

        remote = self.workers.remote_workers()
        if remote:
            # Keep only the highest-index third of the remote workers
            # (rounded up: `-n // 3` floors towards -inf). Metrics are taken
            # from these because of the `PerWorkerEpsilonGreedy` exploration
            # strategy.
            self._remote_workers_for_metrics = remote[-len(remote) // 3:]

        shards = cfg["optimizer"]["num_replay_buffer_shards"]
        replay_cfg = cfg["replay_buffer_config"]
        # Total replay capacity is split evenly across the shards.
        per_shard_capacity = replay_cfg["capacity"] // shards
        # Positional constructor arguments shared by every shard.
        replay_actor_args = [
            shards,
            cfg["learning_starts"],
            per_shard_capacity,
            cfg["train_batch_size"],
            replay_cfg["prioritized_replay_alpha"],
            replay_cfg["prioritized_replay_beta"],
            replay_cfg["prioritized_replay_eps"],
            cfg["multiagent"]["replay_mode"],
            replay_cfg.get("replay_sequence_length", 1),
        ]

        if not cfg["replay_buffer_shards_colocated_with_driver"]:
            # No placement constraint: shards may land on any node(s).
            self.replay_actors = [
                ReplayActor.remote(*replay_actor_args)
                for _ in range(shards)
            ]
        else:
            # Pin all shards to this node, i.e. next to the learner.
            # Spec layout: (class, args, kwargs, count).
            shard_spec = (ReplayActor, replay_actor_args, {}, shards)
            colocated = create_colocated_actors(
                actor_specs=[shard_spec],
                node=platform.node(),  # this host
            )
            # A single spec was passed, so take its group of actors.
            self.replay_actors = colocated[0]

        local_worker = self.workers.local_worker()
        self.learner_thread = LearnerThread(local_worker)
        self.learner_thread.start()
        self.steps_since_update = defaultdict(int)
        # Object-store reference to the learner's current weights.
        self.curr_learner_weights = ray.put(local_worker.get_weights())

        # Outstanding sampling / replay requests, keyed by actor handle.
        self.remote_sampling_requests_in_flight: DefaultDict[
            ActorHandle, Set[ray.ObjectRef]] = defaultdict(set)
        self.remote_replay_requests_in_flight: DefaultDict[
            ActorHandle, Set[ray.ObjectRef]] = defaultdict(set)

        self.curr_num_samples_collected = 0
        self.replay_sample_batches = []
        self._num_ts_trained_since_last_target_update = 0
Esempio n. 2
0
def test_store_to_replay_actor(ray_start_regular_shared):
    """Batches stored via `StoreToReplayBuffer` are replayable only once
    learning has started, and `Replay` then yields them.

    `ray_start_regular_shared` is a pytest fixture (defined elsewhere) that
    presumably provides a running Ray instance.
    """
    replay_actor = ReplayActor.remote(
        num_shards=1,
        learning_starts=200,
        buffer_size=1000,
        replay_batch_size=100,
        prioritized_replay_alpha=0.6,
        prioritized_replay_beta=0.4,
        prioritized_replay_eps=0.0001,
    )

    def replay_once():
        # Ask the actor for one replay batch and block on the result.
        return ray.get(replay_actor.replay.remote())

    # Nothing has been stored yet.
    assert replay_once() is None

    rollouts = ParallelRollouts(make_workers(0), mode="bulk_sync")
    store_op = rollouts.for_each(StoreToReplayBuffer(actors=[replay_actor]))

    next(store_op)
    # Learning hasn't started after the first store.
    assert replay_once() is None

    next(store_op)
    assert replay_once().count == 100

    assert next(Replay(actors=[replay_actor])).count == 100
Esempio n. 3
0
File: apex.py Progetto: alipay/ray
    def execution_plan(workers: WorkerSet, config: dict,
                       **kwargs) -> LocalIterator[dict]:
        """Build the Ape-X training pipeline (execution-plan API).

        Creates the replay-buffer shard actors, starts a ``LearnerThread`` on
        the local worker, and runs three operators concurrently:

        1. Rollouts from ``workers`` are stored in the replay actors. If
           there are remote workers, the weights of the worker that
           generated each batch are also updated.
        2. Batches replayed from the actors pass through the optional
           ``config["before_learn_on_batch"]`` hook and are enqueued on the
           learner thread's in-queue.
        3. Results dequeued from the learner thread update replay
           priorities, the steps-trained counters and learner timers, and
           the target network.

        With ``config["training_intensity"]`` set, the operators run
        round-robin with weights from ``calculate_rr_weights``. Otherwise
        they run asynchronously. In both modes only items from (3) are
        emitted.

        Args:
            workers: Worker set used for sampling; its local worker backs the
                learner thread.
            config: Trainer config dict.
            **kwargs: Must be empty; Ape-X takes no additional parameters.

        Returns:
            An iterator of training-result dicts produced by
            ``StandardMetricsReporting``, with Ape-X learner/replay info
            merged into ``result["info"]``.

        Raises:
            AssertionError: If any extra keyword arguments are passed.
        """
        assert (
            len(kwargs) == 0
        ), "Apex execution_plan does NOT take any additional parameters"

        # Create a number of replay buffer actors.
        num_replay_buffer_shards = config["optimizer"][
            "num_replay_buffer_shards"]
        # Total capacity is split evenly across the shards.
        buffer_size = (config["replay_buffer_config"]["capacity"] //
                       num_replay_buffer_shards)
        replay_actor_args = [
            num_replay_buffer_shards,
            config["learning_starts"],
            buffer_size,
            config["train_batch_size"],
            config["replay_buffer_config"]["prioritized_replay_alpha"],
            config["replay_buffer_config"]["prioritized_replay_beta"],
            config["replay_buffer_config"]["prioritized_replay_eps"],
            config["multiagent"]["replay_mode"],
            config["replay_buffer_config"].get("replay_sequence_length", 1),
        ]
        # Place all replay buffer shards on the same node as the learner
        # (driver process that runs this execution plan).
        if config["replay_buffer_shards_colocated_with_driver"]:
            replay_actors = create_colocated_actors(
                actor_specs=[
                    # (class, args, kwargs={}, count)
                    (ReplayActor, replay_actor_args, {},
                     num_replay_buffer_shards)
                ],
                node=platform.node(),  # localhost
            )[0]  # [0]=only one item in `actor_specs`.
        # Place replay buffer shards on any node(s).
        else:
            # `ReplayActor` is a Ray actor class. It must be launched via
            # `.remote()` (as the colocated path and `setup()` do). Ray
            # rejects calling the class directly, and the code below calls
            # `.remote` methods on these objects (`update_priorities`,
            # `stats`), which only actor handles provide.
            replay_actors = [
                ReplayActor.remote(*replay_actor_args)
                for _ in range(num_replay_buffer_shards)
            ]

        # Start the learner thread.
        learner_thread = LearnerThread(workers.local_worker())
        learner_thread.start()

        # Update experience priorities post learning.
        def update_prio_and_stats(
                item: Tuple[ActorHandle, dict, int, int]) -> None:
            # `item` = (source replay actor, priority dict, env-step count,
            # agent-step count); the agent count is not used here.
            actor, prio_dict, env_count, agent_count = item
            if config.get("prioritized_replay"):
                # Fire-and-forget: the result ref is not awaited.
                actor.update_priorities.remote(prio_dict)
            metrics = _get_shared_metrics()
            # Manually update the steps trained counter since the learner
            # thread is executing outside the pipeline.
            metrics.counters[STEPS_TRAINED_THIS_ITER_COUNTER] = env_count
            metrics.counters[STEPS_TRAINED_COUNTER] += env_count
            metrics.timers["learner_dequeue"] = learner_thread.queue_timer
            metrics.timers["learner_grad"] = learner_thread.grad_timer
            metrics.timers["learner_overall"] = learner_thread.overall_timer

        # We execute the following steps concurrently:
        # (1) Generate rollouts and store them in one of our replay buffer
        # actors. Update the weights of the worker that generated the batch.
        rollouts = ParallelRollouts(workers, mode="async", num_async=2)
        store_op = rollouts.for_each(StoreToReplayBuffer(actors=replay_actors))
        # Only need to update workers if there are remote workers.
        if workers.remote_workers():
            store_op = store_op.zip_with_source_actor().for_each(
                UpdateWorkerWeights(
                    learner_thread,
                    workers,
                    max_weight_sync_delay=(
                        config["optimizer"]["max_weight_sync_delay"]),
                ))

        # (2) Read experiences from one of the replay buffer actors and send
        # to the learner thread via its in-queue.
        # Without a user hook, batches are passed through unchanged.
        post_fn = config.get("before_learn_on_batch") or (lambda b, *a: b)
        replay_op = (Replay(
            actors=replay_actors, num_async=4).for_each(lambda x: post_fn(
                x, workers, config)).zip_with_source_actor().for_each(
                    Enqueue(learner_thread.inqueue)))

        # (3) Get priorities back from learner thread and apply them to the
        # replay buffer actors.
        update_op = (Dequeue(learner_thread.outqueue,
                             check=learner_thread.is_alive).for_each(
                                 update_prio_and_stats).for_each(
                                     UpdateTargetNetwork(
                                         workers,
                                         config["target_network_update_freq"],
                                         by_steps_trained=True)))

        if config["training_intensity"]:
            # Execute (1), (2) with a fixed intensity ratio. The trailing
            # "*" is the weight for (3); presumably "take whatever is
            # available" -- see the `Concurrently` docs.
            rr_weights = calculate_rr_weights(config) + ["*"]
            merged_op = Concurrently(
                [store_op, replay_op, update_op],
                mode="round_robin",
                output_indexes=[2],
                round_robin_weights=rr_weights,
            )
        else:
            # Execute (1), (2), (3) asynchronously as fast as possible. Only
            # output items from (3) since metrics aren't available before
            # then.
            merged_op = Concurrently([store_op, replay_op, update_op],
                                     mode="async",
                                     output_indexes=[2])

        # Add in extra replay and learner metrics to the training result.
        def add_apex_metrics(result: dict) -> dict:
            # Only shard 0's stats are fetched (blocking `ray.get`).
            replay_stats = ray.get(replay_actors[0].stats.remote(
                config["optimizer"].get("debug")))
            exploration_infos = workers.foreach_policy_to_train(
                lambda p, _: p.get_exploration_state())
            result["info"].update({
                "exploration_infos":
                exploration_infos,
                "learner_queue":
                learner_thread.learner_queue_size.stats(),
                LEARNER_INFO:
                copy.deepcopy(learner_thread.learner_info),
                "replay_shard_0":
                replay_stats,
            })
            return result

        # Only report metrics from the workers with the lowest 1/3 of
        # epsilons: the last ceil(n/3) remote workers (`-n // 3` floors
        # towards -inf); an empty list when there are no remote workers.
        selected_workers = workers.remote_workers(
        )[-len(workers.remote_workers()) // 3:]

        return StandardMetricsReporting(
            merged_op, workers, config,
            selected_workers=selected_workers).for_each(add_apex_metrics)