Example #1
class SVGInfTrainer(Trainer):
    """Single agent trainer for SVG(inf)."""

    # pylint: disable=attribute-defined-outside-init

    _name = "SVG(inf)"
    _default_config = DEFAULT_CONFIG
    _policy = SVGInfTorchPolicy

    @override(Trainer)
    def _init(self, config, env_creator):
        self._validate_config(config)
        self.workers = self._make_workers(env_creator,
                                          self._policy,
                                          config,
                                          num_workers=config["num_workers"])
        # Dummy optimizer to log stats since Trainer.collect_metrics is coupled with it
        self.optimizer = PolicyOptimizer(self.workers)

        policy = self.get_policy()
        policy.set_reward_from_config(config["env"], config["env_config"])

        self.replay = NumpyReplayBuffer(policy.observation_space,
                                        policy.action_space,
                                        config["buffer_size"])
        self.replay.add_fields(ReplayField(SampleBatch.ACTION_LOGP))
        self.replay.seed(config["seed"])

    @override(Trainer)
    def _train(self):
        worker = self.workers.local_worker()
        policy = worker.get_policy()

        samples = worker.sample()
        self.optimizer.num_steps_sampled += samples.count
        for row in samples.rows():
            self.replay.add(row)
        stats = policy.get_exploration_info()

        # Initialize in case no off-policy updates run this iteration
        off_policy_stats = {}
        with policy.learning_off_policy():
            for _ in range(int(samples.count *
                               self.config["updates_per_step"])):
                batch = self.replay.sample(self.config["train_batch_size"])
                off_policy_stats = policy.learn_on_batch(batch)
                self.optimizer.num_steps_trained += batch.count
        stats.update(off_policy_stats)

        stats.update(policy.learn_on_batch(samples))

        return self._log_metrics(stats)
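
Note: the replay handling above boils down to a short pattern: build the buffer over the policy's spaces, register any extra columns, seed it, add rows, and sample fixed-size batches. The sketch below is a standalone illustration, not library code; it assumes numpy, gym, and RLlib's SampleBatch, reuses only the NumpyReplayBuffer calls already shown on this page, and the field set of the dummy batch is a guess at what the buffer expects.

# Standalone sketch of the replay pattern used by SVGInfTrainer above.
# NumpyReplayBuffer/ReplayField imports are assumed available as in the examples;
# the dummy SampleBatch fields are illustrative, not a confirmed schema.
import numpy as np
from gym.spaces import Box
from ray.rllib.policy.sample_batch import SampleBatch

obs_space, action_space = Box(-1.0, 1.0, shape=(4,)), Box(-1.0, 1.0, shape=(2,))
replay = NumpyReplayBuffer(obs_space, action_space, size=10_000)
replay.add_fields(ReplayField(SampleBatch.ACTION_LOGP))  # extra stored column
replay.seed(42)

samples = SampleBatch({
    SampleBatch.CUR_OBS: np.random.randn(32, 4).astype(np.float32),
    SampleBatch.ACTIONS: np.random.randn(32, 2).astype(np.float32),
    SampleBatch.NEXT_OBS: np.random.randn(32, 4).astype(np.float32),
    SampleBatch.REWARDS: np.random.randn(32).astype(np.float32),
    SampleBatch.DONES: np.zeros(32, dtype=bool),
    SampleBatch.ACTION_LOGP: np.random.randn(32).astype(np.float32),
})
for row in samples.rows():
    replay.add(row)

batch = replay.sample(16)  # mini-batch for one off-policy update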
Example #2
class OffPolicyMixin(ABC):
    """Adds a replay buffer and standard procedures for `learn_on_batch`."""

    replay: NumpyReplayBuffer

    def build_replay_buffer(self):
        """Construct the experience replay buffer.

        Should be called by subclasses on init.
        """
        self.replay = NumpyReplayBuffer(self.observation_space,
                                        self.action_space,
                                        self.config["buffer_size"])
        self.replay.seed(self.config["seed"])

    @learner_stats
    def learn_on_batch(self, samples: SampleBatch):
        """Run one logical iteration of training.

        Returns:
            An info dict from this iteration.
        """
        self.add_to_buffer(samples)

        info = {}
        info.update(self.get_exploration_info())

        for _ in range(int(self.config["improvement_steps"])):
            batch = self.replay.sample(self.config["batch_size"])
            batch = self.lazy_tensor_dict(batch)
            info.update(self.improve_policy(batch))

        return info

    def add_to_buffer(self, samples: SampleBatch):
        """Add sample batch to replay buffer"""
        self.replay.add(samples)
        if self.config["std_obs"]:
            self.replay.update_obs_stats()

    @abstractmethod
    def improve_policy(self, batch: TensorDict) -> dict:
        """Run one step of Policy Improvement."""

    @staticmethod
    def add_options(policy_cls: type) -> type:
        """Decorator to add default off-policy options used by OffPolicyMixin."""
        return off_policy_options(policy_cls)
Example #3
def numpy_replay(obs_space, action_space, size, sample_batch):
    replay = NumpyReplayBuffer(obs_space, action_space, size)
    replay.add(sample_batch)
    return replay
Example #4
def replay(obs_space, action_space, samples):
    replay = NumpyReplayBuffer(obs_space, action_space, size=samples.count)
    replay.add(samples)
    return replay
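
Note: Examples #3 and #4 read like test helpers that wrap buffer construction. A hypothetical test sketch of how the `replay` helper above might be consumed; the assertion on `len` relies on the `__len__` support seen in Example #5, and `buffer.sample` returning a SampleBatch matches its use elsewhere on this page.

# Hypothetical test using the `replay` helper above; assumes `obs_space`,
# `action_space`, and `samples` (a SampleBatch) are provided, e.g. as fixtures.
def test_replay_roundtrip(obs_space, action_space, samples):
    buffer = replay(obs_space, action_space, samples)
    assert len(buffer) == samples.count          # buffer sized to the added batch
    batch = buffer.sample(min(8, samples.count)) # draw a mini-batch back out
    assert batch.count == min(8, samples.count)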
Example #5
class ModelBasedTrainer(OffPolicyTrainer):
    """Generic trainer for model-based agents."""

    # pylint: disable=attribute-defined-outside-init

    @override(OffPolicyTrainer)
    def _init(self, config, env_creator):
        super()._init(config, env_creator)
        policy = self.get_policy()
        policy.set_reward_from_config(config["env"], config["env_config"])
        policy.set_termination_from_config(config["env"], config["env_config"])

    @staticmethod
    @override(OffPolicyTrainer)
    def validate_config(config):
        OffPolicyTrainer.validate_config(config)
        assert (config["holdout_ratio"] <
                1.0), "Holdout data cannot be the entire dataset"
        assert (config["max_holdout"] >=
                0), "Maximum number of holdout samples must be non-negative"
        assert (config["policy_improvements"] >=
                0), "Number of policy improvement steps must be non-negative"
        assert (
            0 <= config["real_data_ratio"] <= 1
        ), "Fraction of real data samples for policy improvement must be in [0, 1]"
        assert (config["virtual_buffer_size"] >=
                0), "Virtual buffer capacity must be non-negative"
        assert (config["model_rollouts"] >=
                0), "Cannot sample a negative number of model rollouts"

    @override(OffPolicyTrainer)
    def build_replay_buffer(self, config):
        super().build_replay_buffer(config)
        policy = self.get_policy()
        self.virtual_replay = NumpyReplayBuffer(policy.observation_space,
                                                policy.action_space,
                                                config["virtual_buffer_size"])
        self.virtual_replay.seed(config["seed"])

    @override(OffPolicyTrainer)
    def _train(self):
        start_samples = self.sample_until_learning_starts()

        config = self.config
        worker = self.workers.local_worker()
        policy = worker.get_policy()
        stats = {}
        while not self._iteration_done():
            samples = worker.sample()
            self.tracker.num_steps_sampled += samples.count
            for row in samples.rows():
                self.replay.add(row)

            eval_losses, model_train_info = self.train_dynamics_model()
            policy.setup_sampling_models(eval_losses)
            self.populate_virtual_buffer(config["model_rollouts"] *
                                         samples.count)
            policy_train_info = self.improve_policy(
                config["policy_improvements"] * samples.count)

            stats.update(model_train_info)
            stats.update(policy_train_info)

        self.tracker.num_steps_sampled += start_samples
        return self._log_metrics(stats)

    def train_dynamics_model(self) -> Tuple[List[float], Dict[str, float]]:
        """Implements the model training step.

        Calls the policy to optimize the model on the environment replay buffer.

        Returns:
            A tuple containing the list of evaluation losses for each model and
            a dictionary of training statistics
        """
        samples = self.replay.all_samples()
        samples.shuffle()
        holdout = min(
            int(len(self.replay) * self.config["holdout_ratio"]),
            self.config["max_holdout"],
        )
        train_data = samples.slice(holdout, None)
        eval_data = samples.slice(0, holdout)

        policy = self.get_policy()
        eval_losses, stats = policy.optimize_model(train_data, eval_data)

        return eval_losses, stats

    def populate_virtual_buffer(self, num_rollouts: int):
        """Add model rollouts branched from real data to the virtual pool.

        Args:
            num_rollouts: Number of initial states to sample from the
                environment replay buffer
        """
        if not (num_rollouts and self.config["real_data_ratio"] < 1.0):
            return

        real_samples = self.replay.sample(num_rollouts)
        policy = self.get_policy()
        virtual_samples = policy.generate_virtual_sample_batch(real_samples)
        for row in virtual_samples.rows():
            self.virtual_replay.add(row)

    def improve_policy(self, num_improvements: int) -> Dict[str, float]:
        """Call the policy to perform policy improvement using the augmented replay.

        Args:
            num_improvements: Number of times to call `policy.learn_on_batch`

        Returns:
            A dictionary of training and exploration statistics
        """
        policy = self.get_policy()
        batch_size = self.config["train_batch_size"]
        env_batch_size = int(batch_size * self.config["real_data_ratio"])
        model_batch_size = batch_size - env_batch_size

        stats = {}
        for _ in range(num_improvements):
            samples = []
            if env_batch_size:
                samples += [self.replay.sample(env_batch_size)]
            if model_batch_size:
                samples += [self.virtual_replay.sample(model_batch_size)]
            batch = SampleBatch.concat_samples(samples)
            stats = get_learner_stats(policy.learn_on_batch(batch))
            self.tracker.num_steps_trained += batch.count

        stats.update(policy.get_exploration_info())
        return stats
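
Note: the real/virtual split in improve_policy is a plain integer split of train_batch_size. A quick worked example with illustrative values (not the library defaults):

# Worked example of the batch split in `improve_policy` above (values illustrative).
batch_size = 256
real_data_ratio = 0.05
env_batch_size = int(batch_size * real_data_ratio)  # int(12.8) == 12 from `replay`
model_batch_size = batch_size - env_batch_size      # 244 from `virtual_replay`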
Example #6
class MBPOTorchPolicy(MBPolicyMixin, EnvFnMixin, ModelSamplingMixin, SACTorchPolicy):
    """Model-Based Policy Optimization policy in PyTorch to use with RLlib."""

    # pylint:disable=too-many-ancestors
    virtual_replay: NumpyReplayBuffer
    model_trainer: LightningModelTrainer
    dist_class = WrapStochasticPolicy

    def __init__(self, observation_space, action_space, config):
        super().__init__(observation_space, action_space, config)
        models = self.module.models
        self.loss_model = MaximumLikelihood(models)

        self.build_timers()
        self.model_trainer = LightningModelTrainer(
            models=self.module.models,
            loss_fn=self.loss_model,
            optimizer=self.optimizers["models"],
            replay=self.replay,
            config=self.config,
        )

    def _make_optimizers(self):
        optimizers = super()._make_optimizers()
        config = self.config["optimizer"]
        optimizers["models"] = build_optimizer(self.module.models, config["models"])
        return optimizers

    def build_replay_buffer(self):
        super().build_replay_buffer()
        self.virtual_replay = NumpyReplayBuffer(
            self.observation_space,
            self.action_space,
            self.config["virtual_buffer_size"],
        )
        self.virtual_replay.seed(self.config["seed"])

    def build_timers(self):
        super().build_timers()
        self.timers["augmentation"] = TimerStat()

    @learner_stats
    def learn_on_batch(self, samples: SampleBatch) -> dict:
        self.add_to_buffer(samples)
        self._learn_calls += 1

        info = {}
        warmup = self._learn_calls == 1
        if self._learn_calls % self.config["model_update_interval"] == 0 or warmup:
            with self.timers["model"] as timer:
                losses, model_info = self.train_dynamics_model(warmup=warmup)
                timer.push_units_processed(model_info["model_epochs"])
                info.update(model_info)
            self.set_new_elite(losses)

        with self.timers["augmentation"] as timer:
            count_before = len(self.virtual_replay)
            self.populate_virtual_buffer()
            timer.push_units_processed(len(self.virtual_replay) - count_before)

        with self.timers["policy"] as timer:
            times = self.config["improvement_steps"]
            policy_info = self.update_policy(times=times)
            timer.push_units_processed(times)
            info.update(policy_info)

        info.update(self.timer_stats())
        return info

    def train_dynamics_model(
        self, warmup: bool = False
    ) -> Tuple[List[float], StatDict]:
        return self.model_trainer.optimize(warmup=warmup)

    def populate_virtual_buffer(self):
        # pylint:disable=missing-function-docstring
        num_rollouts = self.config["model_rollouts"]
        real_data_ratio = self.config["real_data_ratio"]
        if not (num_rollouts and real_data_ratio < 1.0):
            return

        real_samples = self.replay.sample(num_rollouts)
        virtual_samples = self.generate_virtual_sample_batch(real_samples)
        self.virtual_replay.add(virtual_samples)

    def update_policy(self, times: int) -> StatDict:
        batch_size = self.config["batch_size"]
        env_batch_size = int(batch_size * self.config["real_data_ratio"])
        model_batch_size = batch_size - env_batch_size

        # Initialize in case `times` is zero, so the return below stays defined
        info = {}
        for _ in range(times):
            samples = []
            if env_batch_size:
                samples += [self.replay.sample(env_batch_size)]
            if model_batch_size:
                samples += [self.virtual_replay.sample(model_batch_size)]
            batch = SampleBatch.concat_samples(samples)
            batch = self.lazy_tensor_dict(batch)
            info = self.improve_policy(batch)

        return info

    def timer_stats(self) -> dict:
        stats = super().timer_stats()
        augmentation_timer = self.timers["augmentation"]
        stats.update(
            augmentation_time_s=round(augmentation_timer.mean, 3),
            augmentation_throughput=round(augmentation_timer.mean_throughput, 3),
        )
        return stats
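
Note: MBPOTorchPolicy reads several config keys above (buffer_size via the mixin's build_replay_buffer, virtual_buffer_size, seed, model_update_interval, model_rollouts, real_data_ratio, improvement_steps, batch_size). An illustrative override fragment; the values are placeholders, not the library defaults.

# Illustrative config overrides touching the keys MBPOTorchPolicy reads above.
mbpo_overrides = {
    "buffer_size": 100_000,          # real-experience replay capacity
    "virtual_buffer_size": 400_000,  # model-generated replay capacity
    "model_update_interval": 25,     # retrain the dynamics model every N learn calls
    "model_rollouts": 40,            # initial states branched per augmentation step
    "real_data_ratio": 0.05,         # fraction of each policy batch from real data
    "improvement_steps": 10,         # policy updates per `learn_on_batch` call
    "batch_size": 256,
    "seed": 0,
}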
Example #7
class OffPolicyTrainer(Trainer):
    """Generic trainer for off-policy agents."""

    # pylint: disable=attribute-defined-outside-init
    _name = ""
    _default_config = None
    _policy = None

    @override(Trainer)
    def _init(self, config, env_creator):
        self.validate_config(config)
        self.workers = self._make_workers(env_creator,
                                          self._policy,
                                          config,
                                          num_workers=0)
        self.build_replay_buffer(config)

    @override(Trainer)
    def _train(self):
        start_samples = self.sample_until_learning_starts()

        worker = self.workers.local_worker()
        policy = worker.get_policy()
        stats = {}
        while not self._iteration_done():
            samples = worker.sample()
            self.tracker.num_steps_sampled += samples.count
            for row in samples.rows():
                self.replay.add(row)
            stats.update(policy.get_exploration_info())

            self._before_replay_steps(policy)
            for _ in range(samples.count):
                batch = self.replay.sample(self.config["train_batch_size"])
                stats = get_learner_stats(policy.learn_on_batch(batch))
                self.tracker.num_steps_trained += batch.count

        self.tracker.num_steps_sampled += start_samples
        return self._log_metrics(stats)

    def build_replay_buffer(self, config):
        """Construct replay buffer to hold samples."""
        policy = self.get_policy()
        self.replay = NumpyReplayBuffer(policy.observation_space,
                                        policy.action_space,
                                        config["buffer_size"])
        self.replay.seed(config["seed"])

    def sample_until_learning_starts(self):
        """
        Sample enough transtions so that 'learning_starts' steps are collected before
        the next policy update.
        """
        learning_starts = self.config["learning_starts"]
        worker = self.workers.local_worker()
        sample_count = 0
        while self.tracker.num_steps_sampled + sample_count < learning_starts:
            samples = worker.sample()
            sample_count += samples.count
            for row in samples.rows():
                self.replay.add(row)
        return sample_count

    def _before_replay_steps(self, policy):  # pylint:disable=unused-argument
        pass

    @staticmethod
    def validate_config(config):
        """Assert configuration values are valid."""
        assert config[
            "num_workers"] == 0, "No point in using additional workers."
        assert (config["rollout_fragment_length"] >=
                1), "At least one sample must be collected."
Example #8
class OffPolicyMixin(ABC):
    """Adds a replay buffer and standard procedures for `learn_on_batch`."""

    replay: NumpyReplayBuffer

    def build_replay_buffer(self):
        """Construct the experience replay buffer.

        Should be called by subclasses on init.
        """
        self.replay = NumpyReplayBuffer(self.observation_space,
                                        self.action_space,
                                        self.config["buffer_size"])
        self.replay.seed(self.config["seed"])
        self.replay.compute_stats = self.config["std_obs"]

    @learner_stats
    def learn_on_batch(self, samples: SampleBatch):
        """Run one logical iteration of training.

        Returns:
            An info dict from this iteration.
        """
        self.add_to_buffer(samples)

        info = {}
        info.update(self.get_exploration_info())

        for _ in range(int(self.config["improvement_steps"])):
            batch = self.replay.sample(self.config["batch_size"])
            batch = self.lazy_tensor_dict(batch)
            info.update(self.improve_policy(batch))

        return info

    def add_to_buffer(self, samples: SampleBatch):
        """Add sample batch to replay buffer"""
        self.replay.add(samples)

    @abstractmethod
    def improve_policy(self, batch: TensorDict) -> dict:
        """Run one step of Policy Improvement."""

    def compute_actions(
        self,
        obs_batch: Union[List[TensorType], TensorType],
        state_batches: Optional[List[TensorType]] = None,
        prev_action_batch: Union[List[TensorType], TensorType] = None,
        prev_reward_batch: Union[List[TensorType], TensorType] = None,
        info_batch: Optional[Dict[str, list]] = None,
        episodes: Optional[List[MultiAgentEpisode]] = None,
        explore: Optional[bool] = None,
        timestep: Optional[int] = None,
        **kwargs
    ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
        # pylint:disable=too-many-arguments
        obs_batch = self.replay.normalize(obs_batch)
        return super().compute_actions(
            obs_batch,
            state_batches=state_batches,
            prev_action_batch=prev_action_batch,
            prev_reward_batch=prev_reward_batch,
            info_batch=info_batch,
            episodes=episodes,
            explore=explore,
            timestep=timestep,
            **kwargs,
        )

    def compute_log_likelihoods(
        self,
        actions: Union[List[TensorType], TensorType],
        obs_batch: Union[List[TensorType], TensorType],
        state_batches: Optional[List[TensorType]] = None,
        prev_action_batch: Optional[Union[List[TensorType],
                                          TensorType]] = None,
        prev_reward_batch: Optional[Union[List[TensorType],
                                          TensorType]] = None,
    ) -> TensorType:
        # pylint:disable=too-many-arguments
        obs_batch = self.replay.normalize(obs_batch)
        return super().compute_log_likelihoods(
            actions=actions,
            obs_batch=obs_batch,
            state_batches=state_batches,
            prev_action_batch=prev_action_batch,
            prev_reward_batch=prev_reward_batch,
        )

    def get_weights(self) -> dict:
        state = super().get_weights()
        state["replay"] = self.replay.state_dict()
        return state

    def set_weights(self, weights: dict):
        self.replay.load_state_dict(weights["replay"])
        super().set_weights(
            {k: v
             for k, v in weights.items() if k != "replay"})

    @staticmethod
    def add_options(policy_cls: type) -> type:
        """Decorator to add default off-policy options used by OffPolicyMixin."""
        return off_policy_options(policy_cls)
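
Note: a sketch of how this mixin is meant to be consumed, based only on its docstrings and Example #6: decorate the class with `OffPolicyMixin.add_options`, call `build_replay_buffer` during init, and implement `improve_policy`. The base policy class, optimizer key, and loss routine below are placeholders, not library code.

# Hypothetical policy using OffPolicyMixin; `SomeTorchPolicy` and `loss_fn`
# are placeholders, not names from the library.
@OffPolicyMixin.add_options
class MyOffPolicyTorchPolicy(OffPolicyMixin, SomeTorchPolicy):
    """Store samples in a NumpyReplayBuffer and learn from replayed batches."""

    def __init__(self, observation_space, action_space, config):
        super().__init__(observation_space, action_space, config)
        self.build_replay_buffer()  # called on init, as the docstring requires

    def improve_policy(self, batch: TensorDict) -> dict:
        # One placeholder improvement step on the tensor-converted batch.
        self.optimizers["policy"].zero_grad()
        loss, info = self.loss_fn(batch)
        loss.backward()
        self.optimizers["policy"].step()
        return info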