Example #1
class MyTrainer(Trainer):
    @classmethod
    @override(Trainer)
    def get_default_config(cls) -> TrainerConfigDict:
        # Run this Trainer with new `training_iteration` API and set some PPO-specific
        # parameters.
        return with_common_config({
            "_disable_execution_plan_api": True,
            "num_sgd_iter": 10,
            "sgd_minibatch_size": 128,
        })

    @override(Trainer)
    def setup(self, config):
        # Call super's `setup` to create rollout workers.
        super().setup(config)
        # Create local replay buffer.
        self.local_replay_buffer = MultiAgentReplayBuffer(num_shards=1,
                                                          learning_starts=1000,
                                                          capacity=50000,
                                                          replay_batch_size=64)

    @override(Trainer)
    def training_iteration(self) -> ResultDict:
        # Sample experiences from all workers; keep the PPO policy's batches for
        # on-policy training and store the remaining (DQN) batches in the replay buffer.
        ppo_batches = []
        num_env_steps = 0
        # PPO batch size fixed at 200.
        while num_env_steps < 200:
            ma_batches = synchronous_parallel_sample(self.workers)
            # Loop through the ma-batches (collected in parallel).
            for ma_batch in ma_batches:
                # Update sampled counters.
                self._counters[NUM_ENV_STEPS_SAMPLED] += ma_batch.count
                self._counters[
                    NUM_AGENT_STEPS_SAMPLED] += ma_batch.agent_steps()
                ppo_batch = ma_batch.policy_batches.pop("ppo_policy")
                # Add the remaining batch (now DQN-policy data only) to the replay buffer.
                self.local_replay_buffer.add_batch(ma_batch)

                ppo_batches.append(ppo_batch)
                num_env_steps += ppo_batch.count

        # DQN sub-flow.
        dqn_train_results = {}
        dqn_train_batch = self.local_replay_buffer.replay()
        if dqn_train_batch is not None:
            dqn_train_results = train_one_step(self, dqn_train_batch,
                                               ["dqn_policy"])
            self._counters[
                "agent_steps_trained_DQN"] += dqn_train_batch.agent_steps()
            print(
                "DQN policy learning on",
                dqn_train_batch.agent_steps(),
                "agent steps sampled from the replay buffer.",
            )
        # Update DQN's target network every 500 trained (DQN) agent steps.
        if (self._counters["agent_steps_trained_DQN"] -
                self._counters[LAST_TARGET_UPDATE_TS] >= 500):
            self.workers.local_worker().get_policy(
                "dqn_policy").update_target()
            self._counters[NUM_TARGET_UPDATES] += 1
            self._counters[LAST_TARGET_UPDATE_TS] = self._counters[
                "agent_steps_trained_DQN"]

        # PPO sub-flow.
        ppo_train_batch = SampleBatch.concat_samples(ppo_batches)
        self._counters[
            "agent_steps_trained_PPO"] += ppo_train_batch.agent_steps()
        # Standardize advantages.
        ppo_train_batch[Postprocessing.ADVANTAGES] = standardized(
            ppo_train_batch[Postprocessing.ADVANTAGES])
        print(
            "PPO policy learning on",
            ppo_train_batch.agent_steps(),
            "freshly collected agent steps.",
        )
        ppo_train_batch = MultiAgentBatch({"ppo_policy": ppo_train_batch},
                                          ppo_train_batch.count)
        ppo_train_results = train_one_step(self, ppo_train_batch,
                                           ["ppo_policy"])

        # Combine results for PPO and DQN into one results dict.
        results = dict(ppo_train_results, **dqn_train_results)
        return results
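
A minimal usage sketch for the two-policy Trainer above (not part of the original listing): it assumes the pre-2.0 `ray.rllib.agents` API, RLlib's MultiAgentCartPole example env, and a placeholder policy-mapping rule (even agent ids go to PPO, odd ones to DQN); the per-policy configs reuse the PPO/DQN defaults so each policy class finds its algorithm-specific keys, and the agent count and stopping criterion are arbitrary.

import ray
from ray import tune
from ray.rllib.agents.dqn.dqn import DEFAULT_CONFIG as DQN_CONFIG
from ray.rllib.agents.dqn.dqn_torch_policy import DQNTorchPolicy
from ray.rllib.agents.ppo.ppo import DEFAULT_CONFIG as PPO_CONFIG
from ray.rllib.agents.ppo.ppo_torch_policy import PPOTorchPolicy
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
from ray.tune.registry import register_env

if __name__ == "__main__":
    ray.init()
    # Simple multi-agent env: several independent CartPole instances.
    register_env("multi_agent_cartpole",
                 lambda _: MultiAgentCartPole({"num_agents": 4}))

    config = {
        "env": "multi_agent_cartpole",
        "framework": "torch",
        "multiagent": {
            # Two policies, trained by the two sub-flows in `training_iteration`.
            # Re-use the PPO/DQN default configs so each policy class finds its
            # algorithm-specific keys.
            "policies": {
                "ppo_policy": (PPOTorchPolicy, None, None, PPO_CONFIG),
                "dqn_policy": (DQNTorchPolicy, None, None, DQN_CONFIG),
            },
            # Placeholder mapping: even agent ids -> PPO, odd agent ids -> DQN.
            "policy_mapping_fn": lambda agent_id, *args, **kwargs: (
                "ppo_policy" if agent_id % 2 == 0 else "dqn_policy"),
            "policies_to_train": ["ppo_policy", "dqn_policy"],
        },
    }
    tune.run(MyTrainer, config=config, stop={"training_iteration": 20})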
Example #2
class MARWILTrainer(Trainer):
    @classmethod
    @override(Trainer)
    def get_default_config(cls) -> TrainerConfigDict:
        return DEFAULT_CONFIG

    @override(Trainer)
    def validate_config(self, config: TrainerConfigDict) -> None:
        # Call super's validation method.
        super().validate_config(config)

        if config["num_gpus"] > 1:
            raise ValueError("`num_gpus` > 1 not yet supported for MARWIL!")

        if config["postprocess_inputs"] is False and config["beta"] > 0.0:
            raise ValueError(
                "`postprocess_inputs` must be True for MARWIL (to "
                "calculate accum., discounted returns)!")

    @override(Trainer)
    def get_default_policy_class(self,
                                 config: TrainerConfigDict) -> Type[Policy]:
        if config["framework"] == "torch":
            from ray.rllib.agents.marwil.marwil_torch_policy import MARWILTorchPolicy

            return MARWILTorchPolicy
        else:
            return MARWILTFPolicy

    @override(Trainer)
    def setup(self, config: PartialTrainerConfigDict):
        super().setup(config)
        # `training_iteration` implementation: Setup buffer in `setup`, not
        # in `execution_plan` (deprecated).
        if self.config["_disable_execution_plan_api"] is True:
            self.local_replay_buffer = MultiAgentReplayBuffer(
                learning_starts=self.config["learning_starts"],
                capacity=self.config["replay_buffer_size"],
                replay_batch_size=self.config["train_batch_size"],
                replay_sequence_length=1,
            )

    @override(Trainer)
    def training_iteration(self) -> ResultDict:
        # Collect SampleBatches from sample workers.
        batch = synchronous_parallel_sample(worker_set=self.workers)
        batch = batch.as_multi_agent()
        self._counters[NUM_AGENT_STEPS_SAMPLED] += batch.agent_steps()
        self._counters[NUM_ENV_STEPS_SAMPLED] += batch.env_steps()
        # Add batch to replay buffer.
        self.local_replay_buffer.add_batch(batch)

        # Pull batch from replay buffer and train on it.
        train_batch = self.local_replay_buffer.replay()
        # Train.
        if self.config["simple_optimizer"]:
            train_results = train_one_step(self, train_batch)
        else:
            train_results = multi_gpu_train_one_step(self, train_batch)
        self._counters[NUM_AGENT_STEPS_TRAINED] += batch.agent_steps()
        self._counters[NUM_ENV_STEPS_TRAINED] += batch.env_steps()

        global_vars = {
            "timestep": self._counters[NUM_AGENT_STEPS_SAMPLED],
        }

        # Update weights - after learning on the local worker - on all remote
        # workers.
        if self.workers.remote_workers():
            with self._timers[WORKER_UPDATE_TIMER]:
                self.workers.sync_weights(global_vars=global_vars)

        # Update global vars on local worker as well.
        self.workers.local_worker().set_global_vars(global_vars)

        return train_results

    @staticmethod
    @override(Trainer)
    def execution_plan(workers: WorkerSet, config: TrainerConfigDict,
                       **kwargs) -> LocalIterator[dict]:
        assert (
            len(kwargs) == 0
        ), "Marwill execution_plan does NOT take any additional parameters"

        rollouts = ParallelRollouts(workers, mode="bulk_sync")
        replay_buffer = MultiAgentReplayBuffer(
            learning_starts=config["learning_starts"],
            capacity=config["replay_buffer_size"],
            replay_batch_size=config["train_batch_size"],
            replay_sequence_length=1,
        )

        store_op = rollouts.for_each(
            StoreToReplayBuffer(local_buffer=replay_buffer))

        replay_op = (Replay(local_buffer=replay_buffer).combine(
            ConcatBatches(
                min_batch_size=config["train_batch_size"],
                count_steps_by=config["multiagent"]["count_steps_by"],
            )).for_each(TrainOneStep(workers)))

        train_op = Concurrently([store_op, replay_op],
                                mode="round_robin",
                                output_indexes=[1])

        return StandardMetricsReporting(train_op, workers, config)
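
A minimal usage sketch for the MARWILTrainer above (not part of the original listing), again assuming the pre-2.0 `ray.rllib.agents` API: MARWIL learns from previously collected offline experiences, so the JSON input path below is a placeholder, as is the stopping criterion.

import ray
from ray import tune

if __name__ == "__main__":
    ray.init()
    tune.run(
        MARWILTrainer,
        config={
            # The env is used to derive observation/action spaces; the training
            # data itself comes from the offline `input` below.
            "env": "CartPole-v0",
            "framework": "torch",
            "input": "/path/to/offline/json/data",  # Placeholder path.
            "beta": 1.0,  # 0.0 would reduce MARWIL to plain behavioral cloning.
        },
        stop={"training_iteration": 10},
    )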