Example #1
    def training_iteration(self) -> ResultDict:
        # Shortcut.
        first_worker = self.workers.remote_workers()[0]

        # Run sampling and update steps on each worker in asynchronous fashion.
        sample_and_update_results = asynchronous_parallel_requests(
            remote_requests_in_flight=self.remote_requests_in_flight,
            actors=self.workers.remote_workers(),
            ray_wait_timeout_s=0.0,
            max_remote_requests_in_flight_per_actor=1,
            remote_fn=self._sample_and_train_torch_distributed,
        )

        # For all results collected:
        # - Update our counters and timers.
        # - Update the worker's global_vars.
        # - Build info dict using a LearnerInfoBuilder object.
        learner_info_builder = LearnerInfoBuilder(num_devices=1)
        steps_this_iter = 0
        for worker, results in sample_and_update_results.items():
            for result in results:
                steps_this_iter += result["env_steps"]
                self._counters[NUM_AGENT_STEPS_SAMPLED] += result["agent_steps"]
                self._counters[NUM_AGENT_STEPS_TRAINED] += result["agent_steps"]
                self._counters[NUM_ENV_STEPS_SAMPLED] += result["env_steps"]
                self._counters[NUM_ENV_STEPS_TRAINED] += result["env_steps"]
                self._timers[LEARN_ON_BATCH_TIMER].push(result["learn_on_batch_time"])
                self._timers[SAMPLE_TIMER].push(result["sample_time"])
                # Add partial learner info to builder object.
                learner_info_builder.add_learn_on_batch_results_multi_agent(
                    result["info"]
                )

            # Broadcast the local set of global vars.
            global_vars = {"timestep": self._counters[NUM_AGENT_STEPS_SAMPLED]}
            for worker in self.workers.remote_workers():
                worker.set_global_vars.remote(global_vars)

        self._counters[STEPS_TRAINED_THIS_ITER_COUNTER] = steps_this_iter

        # Sync down the weights from 1st remote worker (only if we have received
        # some results from it).
        # As with the sync up, this is not really needed unless the user is
        # reading the local weights.
        if (
            self.config["keep_local_weights_in_sync"]
            and first_worker in sample_and_update_results
        ):
            self.workers.local_worker().set_weights(
                ray.get(first_worker.get_weights.remote())
            )
        # Return merged learner info results.
        new_learner_info = learner_info_builder.finalize()
        if new_learner_info:
            self._curr_learner_info = new_learner_info
        return self._curr_learner_info
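
The examples on this page all follow the same calling pattern. The sketch below distills it, assuming only what the code above already shows: `asynchronous_parallel_requests` returns a dict that maps each actor to the list of results completed since the last call. The function and variable names in the sketch (`poll_sample_stats`, `sample_stats`) are hypothetical, and the imports are the same ones the surrounding examples rely on.

def poll_sample_stats(workers, remote_requests_in_flight):
    # Hedged sketch of the shared pattern: launch (or re-launch) requests on all
    # remote workers, then consume only the results that are already finished.
    def sample_stats(worker):
        # Hypothetical remote function; runs on the RolloutWorker actor.
        batch = worker.sample()
        return {"env_steps": batch.env_steps(), "agent_steps": batch.agent_steps()}

    ready = asynchronous_parallel_requests(
        remote_requests_in_flight=remote_requests_in_flight,
        actors=workers.remote_workers(),
        ray_wait_timeout_s=0.0,  # Non-blocking: only already-finished requests return.
        max_remote_requests_in_flight_per_actor=1,
        remote_fn=sample_stats,
    )
    # `ready` maps ActorHandle -> list of results returned since the previous call;
    # the list may be empty for some (or all) actors.
    return sum(r["env_steps"] for results in ready.values() for r in results)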
Example #2
File: apex.py Project: alipay/ray
    def get_samples_and_store_to_replay_buffers(self):
        # In the case num_workers == 0, sample on the local worker.
        if not self.workers.remote_workers():
            with self._timers[SAMPLE_TIMER]:
                local_sampling_worker = self.workers.local_worker()
                batch = local_sampling_worker.sample()
                actor = random.choice(self.replay_actors)
                ray.get(actor.add_batch.remote(batch))
                batch_statistics = {
                    local_sampling_worker: [{
                        "agent_steps": batch.agent_steps(),
                        "env_steps": batch.env_steps(),
                    }]
                }
                return batch_statistics

        def remote_worker_sample_and_store(worker: RolloutWorker,
                                           replay_actors: List[ReplayActor]):
            # This function runs as a remote function on the sampling workers and
            # should only ever be called via RolloutWorker's `apply()` method. It
            # gathers samples and triggers storing them in the replay actors
            # directly from the rollout worker, instead of returning the sample
            # object refs to the driver process and storing them from there.
            _batch = worker.sample()
            _actor = random.choice(replay_actors)
            _actor.add_batch.remote(_batch)
            _batch_statistics = {
                "agent_steps": _batch.agent_steps(),
                "env_steps": _batch.env_steps(),
            }
            return _batch_statistics

        # Sample on the remote workers and store the samples in the replay actors.
        with self._timers[SAMPLE_TIMER]:
            # Results are a mapping from ActorHandle (RolloutWorker) to the
            # batch statistics returned by `remote_worker_sample_and_store`.
            num_samples_ready_dict: Dict[
                ActorHandle, T
            ] = asynchronous_parallel_requests(
                remote_requests_in_flight=self.remote_sampling_requests_in_flight,
                actors=self.workers.remote_workers(),
                ray_wait_timeout_s=0.1,
                max_remote_requests_in_flight_per_actor=4,
                remote_fn=remote_worker_sample_and_store,
                remote_kwargs=[{"replay_actors": self.replay_actors}]
                * len(self.workers.remote_workers()),
            )
        return num_samples_ready_dict
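
Two details of Example #2 are easy to miss: the remote function stores the batch on a replay actor from inside the rollout worker (so the batch never travels through the driver), and `remote_kwargs` is a list with one kwargs dict per entry in `actors`. Below is a hedged sketch of that pairing; the wrapper function name is hypothetical and the imports are assumed to match the examples above.

import random

def launch_sample_and_store(workers, replay_actors, requests_in_flight):
    # Hedged sketch: every worker receives the same replay-actor list via its
    # own kwargs dict (remote_kwargs is paired element-wise with `actors`).
    def sample_and_store(worker, replay_actors):
        batch = worker.sample()
        random.choice(replay_actors).add_batch.remote(batch)  # Store remotely.
        return {"agent_steps": batch.agent_steps(), "env_steps": batch.env_steps()}

    actors = workers.remote_workers()
    return asynchronous_parallel_requests(
        remote_requests_in_flight=requests_in_flight,
        actors=actors,
        ray_wait_timeout_s=0.1,
        max_remote_requests_in_flight_per_actor=4,
        remote_fn=sample_and_store,
        remote_kwargs=[{"replay_actors": replay_actors}] * len(actors),
    )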
Example #3
    def get_samples_from_workers(self) -> Dict[ActorHandle, List[SampleBatch]]:
        # Perform asynchronous sampling on all (remote) rollout workers.
        if self.workers.remote_workers():
            sample_batches: Dict[
                ActorHandle, List[ObjectRef]
            ] = asynchronous_parallel_requests(
                remote_requests_in_flight=self.remote_requests_in_flight,
                actors=self.workers.remote_workers(),
                ray_wait_timeout_s=self.config["sample_wait_timeout"],
                max_remote_requests_in_flight_per_actor=self.config[
                    "max_sample_requests_in_flight_per_worker"
                ],
                return_result_obj_ref_ids=True,
            )
        else:
            # Only sample on the local worker.
            sample_batches = {
                self.workers.local_worker(): [self.workers.local_worker().sample()]
            }
        return sample_batches
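
Note the `return_result_obj_ref_ids=True` flag in Example #3: judging by the type annotation, the returned dict then holds unfetched `ObjectRef`s rather than materialized `SampleBatch`es, which is what lets Example #4 forward them to aggregator actors without pulling the data onto the driver. A small hedged sketch of both options (the function name is hypothetical):

import ray

def flatten_refs(sample_batches, fetch=False):
    # Hedged sketch: flatten the per-actor lists of ObjectRefs. The refs can be
    # forwarded to other actors as-is, or fetched on the driver with ray.get().
    refs = [ref for refs_per_actor in sample_batches.values() for ref in refs_per_actor]
    return ray.get(refs) if fetch else refs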
Example #4
    def process_experiences_tree_aggregation(
        self, actor_to_sample_batches_refs: Dict[ActorHandle, List[ObjectRef]]
    ) -> Union[SampleBatchType, None]:
        batches = [
            sample_batch_ref
            for refs_batch in actor_to_sample_batches_refs.values()
            for sample_batch_ref in refs_batch
        ]
        ready_processed_batches = []
        for batch in batches:
            aggregator = random.choice(self.aggregator_workers)
            processed_sample_batches: Dict[
                ActorHandle, List[ObjectRef]
            ] = asynchronous_parallel_requests(
                remote_requests_in_flight=self.remote_aggregator_requests_in_flight,
                actors=[aggregator],
                remote_fn=lambda actor, b: actor.process_episodes(b),
                remote_kwargs=[{"b": batch}],
                ray_wait_timeout_s=self.config["aggregator_wait_timeout"],
                max_remote_requests_in_flight_per_actor=float("inf"),
            )
            for ready_sub_batches in processed_sample_batches.values():
                ready_processed_batches.extend(ready_sub_batches)

        waiting_processed_sample_batches: Dict[
            ActorHandle, List[ObjectRef]
        ] = wait_asynchronous_requests(
            remote_requests_in_flight=self.remote_aggregator_requests_in_flight,
            ray_wait_timeout_s=self.config["aggregator_wait_timeout"],
        )
        for ready_sub_batches in waiting_processed_sample_batches.values():
            ready_processed_batches.extend(ready_sub_batches)

        return ready_processed_batches
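
Example #4 also shows the complementary utility: `wait_asynchronous_requests` does not launch any new work, it only drains results of requests already tracked in `remote_requests_in_flight`. A hedged sketch of this launch-then-drain step (the function name is hypothetical):

def drain_in_flight_requests(requests_in_flight, timeout_s):
    # Hedged sketch: collect whatever already-launched requests finish within
    # timeout_s; nothing new is submitted here.
    ready = wait_asynchronous_requests(
        remote_requests_in_flight=requests_in_flight,
        ray_wait_timeout_s=timeout_s,
    )
    done = []
    for results in ready.values():
        done.extend(results)
    return done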
Example #5
File: alpha_star.py Project: smorad/ray
    def training_iteration(self) -> ResultDict:
        # Trigger asynchronous rollouts on all RolloutWorkers.
        # - Rollout results are sent directly to correct replay buffer
        #   shards, instead of here (to the driver).
        with self._timers[SAMPLE_TIMER]:
            sample_results = asynchronous_parallel_requests(
                remote_requests_in_flight=self.remote_requests_in_flight,
                actors=self.workers.remote_workers()
                or [self.workers.local_worker()],
                ray_wait_timeout_s=self.config["sample_wait_timeout"],
                max_remote_requests_in_flight_per_actor=2,
                remote_fn=self._sample_and_send_to_buffer,
            )
        # Update sample counters.
        for sample_result in sample_results.values():
            for (env_steps, agent_steps) in sample_result:
                self._counters[NUM_ENV_STEPS_SAMPLED] += env_steps
                self._counters[NUM_AGENT_STEPS_SAMPLED] += agent_steps

        # Trigger asynchronous training update requests on all learning
        # policies.
        with self._timers[LEARN_ON_BATCH_TIMER]:
            pol_actors = []
            args = []
            for pid, pol_actor, repl_actor in self.distributed_learners:
                pol_actors.append(pol_actor)
                args.append([repl_actor, pid])
            train_results = asynchronous_parallel_requests(
                remote_requests_in_flight=self.remote_requests_in_flight,
                actors=pol_actors,
                ray_wait_timeout_s=self.config["learn_wait_timeout"],
                max_remote_requests_in_flight_per_actor=2,
                remote_fn=self._update_policy,
                remote_args=args,
            )

        # Update train step counters.
        for train_result in train_results.values():
            for result in train_result:
                if NUM_AGENT_STEPS_TRAINED in result:
                    self._counters[NUM_AGENT_STEPS_TRAINED] += result[
                        NUM_AGENT_STEPS_TRAINED]

        # For those policies that have been updated in this iteration
        # (not all policies may have undergone an update, as we are
        # requesting updates asynchronously):
        # - Gather train infos.
        # - Push the updated weights to those remote rollout workers that
        #   contain the respective policy.
        with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
            train_infos = {}
            policy_weights = {}
            for pol_actor, policy_results in train_results.items():
                results_have_same_structure = True
                for result1, result2 in zip(policy_results,
                                            policy_results[1:]):
                    try:
                        tree.assert_same_structure(result1, result2)
                    except (ValueError, TypeError):
                        results_have_same_structure = False
                        break
                if len(policy_results) > 1 and results_have_same_structure:
                    policy_result = tree.map_structure(
                        lambda *_args: sum(_args) / len(policy_results),
                        *policy_results)
                else:
                    policy_result = policy_results[-1]
                if policy_result:
                    pid = self.distributed_learners.get_policy_id(pol_actor)
                    train_infos[pid] = policy_result
                    policy_weights[pid] = pol_actor.get_weights.remote()

            policy_weights_ref = ray.put(policy_weights)

            global_vars = {
                "timestep": self._counters[NUM_ENV_STEPS_SAMPLED],
                "league_builder": self.league_builder.__getstate__(),
            }

            for worker in self.workers.remote_workers():
                worker.set_weights.remote(policy_weights_ref, global_vars)

        return train_infos
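
The averaging step in Example #5 relies on `tree.map_structure` (the dm-tree package) to average several result dicts leaf by leaf, provided they share the same nested structure. A self-contained illustration with made-up numbers:

import tree  # dm-tree

# Two per-request learner results with identical nested structure.
policy_results = [
    {"learner_stats": {"loss": 2.0, "entropy": 0.25}},
    {"learner_stats": {"loss": 4.0, "entropy": 0.75}},
]
averaged = tree.map_structure(
    lambda *values: sum(values) / len(values), *policy_results
)
assert averaged == {"learner_stats": {"loss": 3.0, "entropy": 0.5}}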
Example #6
File: a3c.py Project: wuisawesome/ray
    def training_iteration(self) -> ResultDict:
        # Shortcut.
        local_worker = self.workers.local_worker()

        # Define the function executed in parallel by all RolloutWorkers to collect
        # samples + compute and return gradients (and other information).

        def sample_and_compute_grads(worker: RolloutWorker) -> Dict[str, Any]:
            """Call sample() and compute_gradients() remotely on workers."""
            samples = worker.sample()
            grads, infos = worker.compute_gradients(samples)
            return {
                "grads": grads,
                "infos": infos,
                "agent_steps": samples.agent_steps(),
                "env_steps": samples.env_steps(),
            }

        # Perform rollouts and gradient calculations asynchronously.
        with self._timers[GRAD_WAIT_TIMER]:
            # Results are a mapping from ActorHandle (RolloutWorker) to their
            # returned gradient calculation results.
            async_results: Dict[
                ActorHandle, Dict] = asynchronous_parallel_requests(
                    remote_requests_in_flight=self.remote_requests_in_flight,
                    actors=self.workers.remote_workers(),
                    ray_wait_timeout_s=0.0,
                    max_remote_requests_in_flight_per_actor=1,
                    remote_fn=sample_and_compute_grads,
                )

        # Loop through all fetched worker-computed gradients (if any)
        # and apply them - one by one - to the local worker's model.
        # After each apply step (one step per worker that returned some gradients),
        # update that particular worker's weights.
        global_vars = None
        learner_info_builder = LearnerInfoBuilder(num_devices=1)
        for worker, result in async_results.items():
            # Apply gradients to local worker.
            with self._timers[APPLY_GRADS_TIMER]:
                local_worker.apply_gradients(result["grads"])
            self._timers[APPLY_GRADS_TIMER].push_units_processed(
                result["agent_steps"])

            # Update all step counters.
            self._counters[NUM_AGENT_STEPS_SAMPLED] += result["agent_steps"]
            self._counters[NUM_ENV_STEPS_SAMPLED] += result["env_steps"]
            self._counters[NUM_AGENT_STEPS_TRAINED] += result["agent_steps"]
            self._counters[NUM_ENV_STEPS_TRAINED] += result["env_steps"]

            # Create current global vars.
            global_vars = {
                "timestep": self._counters[NUM_AGENT_STEPS_SAMPLED],
            }

            # Synch updated weights back to the particular worker.
            with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
                weights = local_worker.get_weights(
                    local_worker.get_policies_to_train())
                worker.set_weights.remote(weights, global_vars)

            learner_info_builder.add_learn_on_batch_results_multi_agent(
                result["infos"])

        # Update global vars of the local worker.
        if global_vars:
            local_worker.set_global_vars(global_vars)

        return learner_info_builder.finalize()
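
Examples #1 and #6 both merge per-worker training metrics with a `LearnerInfoBuilder`. Reduced to a hedged sketch (the wrapper function name is hypothetical; only the builder calls shown above are assumed):

def merge_learner_infos(per_worker_infos):
    # Hedged sketch: feed each worker's multi-agent "infos" dict into the
    # builder, then finalize() into one merged learner-info dict.
    builder = LearnerInfoBuilder(num_devices=1)
    for infos in per_worker_infos:
        builder.add_learn_on_batch_results_multi_agent(infos)
    return builder.finalize()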
Example #7
    def training_iteration(self) -> ResultDict:
        # Trigger asynchronous rollouts on all RolloutWorkers.
        # - Rollout results are sent directly to correct replay buffer
        #   shards, instead of here (to the driver).
        with self._timers[SAMPLE_TIMER]:
            sample_results = asynchronous_parallel_requests(
                remote_requests_in_flight=self.remote_requests_in_flight,
                actors=self.workers.remote_workers()
                or [self.workers.local_worker()],
                ray_wait_timeout_s=0.01,
                max_remote_requests_in_flight_per_actor=2,
                remote_fn=self._sample_and_send_to_buffer,
            )
        # Update sample counters.
        for (env_steps, agent_steps) in sample_results.values():
            self._counters[NUM_ENV_STEPS_SAMPLED] += env_steps
            self._counters[NUM_AGENT_STEPS_SAMPLED] += agent_steps

        # Trigger asynchronous training update requests on all learning
        # policies.
        with self._timers[LEARN_ON_BATCH_TIMER]:
            pol_actors = []
            args = []
            for pid, pol_actor, repl_actor in self.distributed_learners:
                pol_actors.append(pol_actor)
                args.append([repl_actor, pid])
            train_results = asynchronous_parallel_requests(
                remote_requests_in_flight=self.remote_requests_in_flight,
                actors=pol_actors,
                ray_wait_timeout_s=0.1,
                max_remote_requests_in_flight_per_actor=2,
                remote_fn=self._update_policy,
                remote_args=args,
            )

        # Update train step counters.
        for result in train_results.values():
            if NUM_AGENT_STEPS_TRAINED in result:
                self._counters[NUM_AGENT_STEPS_TRAINED] += result[
                    NUM_AGENT_STEPS_TRAINED]

        # For those policies that have been updated in this iteration
        # (not all policies may have undergone an update, as we are
        # requesting updates asynchronously):
        # - Gather train infos.
        # - Push the updated weights to those remote rollout workers that
        #   contain the respective policy.
        with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
            train_infos = {}
            policy_weights = {}
            for pol_actor, policy_result in train_results.items():
                if policy_result:
                    pid = self.distributed_learners.get_policy_id(pol_actor)
                    train_infos[pid] = policy_result
                    policy_weights[pid] = pol_actor.get_weights.remote()

            policy_weights_ref = ray.put(policy_weights)

            global_vars = {
                "timestep": self._counters[NUM_ENV_STEPS_SAMPLED],
                "win_rates": self.win_rates,
            }

            for worker in self.workers.remote_workers():
                worker.set_weights.remote(policy_weights_ref, global_vars)

        return train_infos
Example #8
File: apex.py Project: alipay/ray
    def sample_from_replay_buffer_place_on_learner_queue_non_blocking(
            self, num_samples_collected: Dict[ActorHandle, int]) -> None:
        """Get samples from the replay buffer and place them on the learner queue.

        Args:
            num_samples_collected: A mapping from ActorHandle (RolloutWorker) to
                number of samples returned by the remote worker. This is used to
                implement training intensity which is the concept of triggering a
                certain amount of training based on the number of samples that have
                been collected since the last time that training was triggered.

        """
        def wait_on_replay_actors(timeout: float) -> None:
            """Wait for the replay actors to finish sampling for timeout seconds.
            If the timeout is None, then block on the actors indefinitely.
            """
            replay_samples_ready: Dict[ActorHandle, T] = wait_asynchronous_requests(
                remote_requests_in_flight=self.remote_replay_requests_in_flight,
                ray_wait_timeout_s=timeout,
            )

            for replay_actor, sample_batches in replay_samples_ready.items():
                for sample_batch in sample_batches:
                    self.replay_sample_batches.append(
                        (replay_actor, sample_batch))

        num_samples_collected = sum(num_samples_collected.values())
        self.curr_num_samples_collected += num_samples_collected
        if self.curr_num_samples_collected >= self.config["train_batch_size"]:
            wait_on_replay_actors(None)
            training_intensity = int(self.config["training_intensity"] or 1)
            num_requests_to_launch = (
                self.curr_num_samples_collected /
                self.config["train_batch_size"]) * training_intensity
            num_requests_to_launch = max(1, round(num_requests_to_launch))
            self.curr_num_samples_collected = 0
            for _ in range(num_requests_to_launch):
                rand_actor = random.choice(self.replay_actors)
                replay_samples_ready: Dict[
                    ActorHandle, T
                ] = asynchronous_parallel_requests(
                    remote_requests_in_flight=self.remote_replay_requests_in_flight,
                    actors=[rand_actor],
                    ray_wait_timeout_s=0.1,
                    max_remote_requests_in_flight_per_actor=num_requests_to_launch,
                    remote_fn=lambda actor: actor.replay(),
                )
            for replay_actor, sample_batches in replay_samples_ready.items():
                for sample_batch in sample_batches:
                    self.replay_sample_batches.append(
                        (replay_actor, sample_batch))

        wait_on_replay_actors(0.1)

        # Add the sample batches to the learner queue.
        while self.replay_sample_batches:
            try:
                item = self.replay_sample_batches[0]
                # The replay buffer returns None if it has not been filled to
                # the minimum threshold yet.
                if item:
                    self.learner_thread.inqueue.put(
                        self.replay_sample_batches[0], timeout=0.001)
                    self.replay_sample_batches.pop(0)
            except queue.Full:
                break
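
For reference, the training-intensity arithmetic in Example #8 works out as follows (a self-contained illustration with made-up config values; the real values come from `self.config`):

# Made-up numbers: 1024 new samples since the last trigger, a train batch size
# of 512, and a training intensity of 4 yield 8 replay requests.
curr_num_samples_collected = 1024
train_batch_size = 512
training_intensity = 4

num_requests_to_launch = (curr_num_samples_collected / train_batch_size) * training_intensity
num_requests_to_launch = max(1, round(num_requests_to_launch))
assert num_requests_to_launch == 8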