class SVGInfTrainer(Trainer):
    """Single agent trainer for SVG(inf)."""

    # pylint: disable=attribute-defined-outside-init
    _name = "SVG(inf)"
    _default_config = DEFAULT_CONFIG
    _policy = SVGInfTorchPolicy

    @override(Trainer)
    def _init(self, config, env_creator):
        self._validate_config(config)
        self.workers = self._make_workers(
            env_creator, self._policy, config, num_workers=config["num_workers"]
        )
        # Dummy optimizer to log stats since Trainer.collect_metrics is coupled with it
        self.optimizer = PolicyOptimizer(self.workers)

        policy = self.get_policy()
        policy.set_reward_from_config(config["env"], config["env_config"])

        self.replay = NumpyReplayBuffer(
            policy.observation_space, policy.action_space, config["buffer_size"]
        )
        self.replay.add_fields(ReplayField(SampleBatch.ACTION_LOGP))
        self.replay.seed(config["seed"])

    @override(Trainer)
    def _train(self):
        worker = self.workers.local_worker()
        policy = worker.get_policy()

        samples = worker.sample()
        self.optimizer.num_steps_sampled += samples.count
        for row in samples.rows():
            self.replay.add(row)

        stats = policy.get_exploration_info()
        with policy.learning_off_policy():
            for _ in range(int(samples.count * self.config["updates_per_step"])):
                batch = self.replay.sample(self.config["train_batch_size"])
                off_policy_stats = policy.learn_on_batch(batch)
                self.optimizer.num_steps_trained += batch.count
        stats.update(off_policy_stats)

        stats.update(policy.learn_on_batch(samples))
        return self._log_metrics(stats)
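# Illustrative sketch (not part of the trainer above): how the number of
# off-policy updates per iteration scales with the rollout size, mirroring the
# `int(samples.count * updates_per_step)` expression in `_train`. The numeric
# values below are hypothetical.
def num_off_policy_updates(rollout_count: int, updates_per_step: float) -> int:
    """Replay updates performed after collecting `rollout_count` env steps."""
    return int(rollout_count * updates_per_step)


assert num_off_policy_updates(rollout_count=200, updates_per_step=0.5) == 100
assert num_off_policy_updates(rollout_count=200, updates_per_step=2.0) == 400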
class OffPolicyMixin(ABC):
    """Adds a replay buffer and standard procedures for `learn_on_batch`."""

    replay: NumpyReplayBuffer

    def build_replay_buffer(self):
        """Construct the experience replay buffer.

        Should be called by subclasses on init.
        """
        self.replay = NumpyReplayBuffer(
            self.observation_space, self.action_space, self.config["buffer_size"]
        )
        self.replay.seed(self.config["seed"])

    @learner_stats
    def learn_on_batch(self, samples: SampleBatch):
        """Run one logical iteration of training.

        Returns:
            An info dict from this iteration.
        """
        self.add_to_buffer(samples)

        info = {}
        info.update(self.get_exploration_info())
        for _ in range(int(self.config["improvement_steps"])):
            batch = self.replay.sample(self.config["batch_size"])
            batch = self.lazy_tensor_dict(batch)
            info.update(self.improve_policy(batch))

        return info

    def add_to_buffer(self, samples: SampleBatch):
        """Add sample batch to replay buffer."""
        self.replay.add(samples)
        if self.config["std_obs"]:
            self.replay.update_obs_stats()

    @abstractmethod
    def improve_policy(self, batch: TensorDict) -> dict:
        """Run one step of Policy Improvement."""

    @staticmethod
    def add_options(policy_cls: type) -> type:
        """Decorator to add default off-policy options used by OffPolicyMixin."""
        return off_policy_options(policy_cls)
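# Self-contained toy of the `add_options` decorator pattern above: a class
# decorator that merges default off-policy config entries into the policy
# class. `_toy_off_policy_options` and `_TOY_DEFAULTS` are illustrative
# stand-ins, not the library's actual `off_policy_options`.
_TOY_DEFAULTS = {"buffer_size": 10_000, "batch_size": 128, "improvement_steps": 1}


def _toy_off_policy_options(policy_cls: type) -> type:
    merged = dict(_TOY_DEFAULTS)
    merged.update(getattr(policy_cls, "options", {}))
    policy_cls.options = merged
    return policy_cls


@_toy_off_policy_options
class _ToyOffPolicy:
    options = {"batch_size": 256}


assert _ToyOffPolicy.options == {
    "buffer_size": 10_000,
    "batch_size": 256,
    "improvement_steps": 1,
}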
def numpy_replay(obs_space, action_space, size, sample_batch):
    """Return a NumpyReplayBuffer of the given size pre-filled with `sample_batch`."""
    replay = NumpyReplayBuffer(obs_space, action_space, size)
    replay.add(sample_batch)
    return replay
def replay(obs_space, action_space, samples):
    """Return a NumpyReplayBuffer sized to and pre-filled with `samples`."""
    replay = NumpyReplayBuffer(obs_space, action_space, size=samples.count)
    replay.add(samples)
    return replay
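# Hedged usage sketch for the helpers above (they read like test fixtures):
# build toy spaces and a SampleBatch, then wrap them in a buffer sized to the
# batch. Assumes `replay` and its NumpyReplayBuffer dependency are in scope;
# field names and shapes are illustrative.
import numpy as np
from gym.spaces import Box
from ray.rllib.policy.sample_batch import SampleBatch

obs_space = Box(-1.0, 1.0, shape=(3,), dtype=np.float32)
action_space = Box(-1.0, 1.0, shape=(1,), dtype=np.float32)
samples = SampleBatch({
    SampleBatch.CUR_OBS: np.random.uniform(size=(10, 3)).astype(np.float32),
    SampleBatch.ACTIONS: np.random.uniform(size=(10, 1)).astype(np.float32),
    SampleBatch.REWARDS: np.zeros(10, dtype=np.float32),
    SampleBatch.NEXT_OBS: np.random.uniform(size=(10, 3)).astype(np.float32),
    SampleBatch.DONES: np.zeros(10, dtype=bool),
})

buffer = replay(obs_space, action_space, samples)
minibatch = buffer.sample(4)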
class ModelBasedTrainer(OffPolicyTrainer):
    """Generic trainer for model-based agents."""

    # pylint: disable=attribute-defined-outside-init
    @override(OffPolicyTrainer)
    def _init(self, config, env_creator):
        super()._init(config, env_creator)
        policy = self.get_policy()
        policy.set_reward_from_config(config["env"], config["env_config"])
        policy.set_termination_from_config(config["env"], config["env_config"])

    @staticmethod
    @override(OffPolicyTrainer)
    def validate_config(config):
        OffPolicyTrainer.validate_config(config)
        assert (
            config["holdout_ratio"] < 1.0
        ), "Holdout data cannot be the entire dataset"
        assert (
            config["max_holdout"] >= 0
        ), "Maximum number of holdout samples must be non-negative"
        assert (
            config["policy_improvements"] >= 0
        ), "Number of policy improvement steps must be non-negative"
        assert (
            0 <= config["real_data_ratio"] <= 1
        ), "Fraction of real data samples for policy improvement must be in [0, 1]"
        assert (
            config["virtual_buffer_size"] >= 0
        ), "Virtual buffer capacity must be non-negative"
        assert (
            config["model_rollouts"] >= 0
        ), "Cannot sample a negative number of model rollouts"

    @override(OffPolicyTrainer)
    def build_replay_buffer(self, config):
        super().build_replay_buffer(config)
        policy = self.get_policy()
        self.virtual_replay = NumpyReplayBuffer(
            policy.observation_space, policy.action_space, config["virtual_buffer_size"]
        )
        self.virtual_replay.seed(config["seed"])

    @override(OffPolicyTrainer)
    def _train(self):
        start_samples = self.sample_until_learning_starts()

        config = self.config
        worker = self.workers.local_worker()
        policy = worker.get_policy()
        stats = {}
        while not self._iteration_done():
            samples = worker.sample()
            self.tracker.num_steps_sampled += samples.count
            for row in samples.rows():
                self.replay.add(row)

            eval_losses, model_train_info = self.train_dynamics_model()
            policy.setup_sampling_models(eval_losses)
            self.populate_virtual_buffer(config["model_rollouts"] * samples.count)
            policy_train_info = self.improve_policy(
                config["policy_improvements"] * samples.count
            )

            stats.update(model_train_info)
            stats.update(policy_train_info)

        self.tracker.num_steps_sampled += start_samples
        return self._log_metrics(stats)

    def train_dynamics_model(self) -> Tuple[List[float], Dict[str, float]]:
        """Implement the model training step.

        Calls the policy to optimize the model on the environment replay buffer.

        Returns:
            A tuple containing the list of evaluation losses for each model and
            a dictionary of training statistics
        """
        samples = self.replay.all_samples()
        samples.shuffle()
        holdout = min(
            int(len(self.replay) * self.config["holdout_ratio"]),
            self.config["max_holdout"],
        )
        train_data, eval_data = samples.slice(holdout, None), samples.slice(0, holdout)

        policy = self.get_policy()
        eval_losses, stats = policy.optimize_model(train_data, eval_data)
        return eval_losses, stats

    def populate_virtual_buffer(self, num_rollouts: int):
        """Add model rollouts branched from real data to the virtual pool.

        Args:
            num_rollouts: Number of initial states to sample from the
                environment replay buffer
        """
        if not (num_rollouts and self.config["real_data_ratio"] < 1.0):
            return

        real_samples = self.replay.sample(num_rollouts)
        policy = self.get_policy()
        virtual_samples = policy.generate_virtual_sample_batch(real_samples)
        for row in virtual_samples.rows():
            self.virtual_replay.add(row)

    def improve_policy(self, num_improvements: int) -> Dict[str, float]:
        """Call the policy to perform policy improvement using the augmented replay.

        Args:
            num_improvements: Number of times to call `policy.learn_on_batch`

        Returns:
            A dictionary of training and exploration statistics
        """
        policy = self.get_policy()
        batch_size = self.config["train_batch_size"]
        env_batch_size = int(batch_size * self.config["real_data_ratio"])
        model_batch_size = batch_size - env_batch_size

        stats = {}
        for _ in range(num_improvements):
            samples = []
            if env_batch_size:
                samples += [self.replay.sample(env_batch_size)]
            if model_batch_size:
                samples += [self.virtual_replay.sample(model_batch_size)]
            batch = SampleBatch.concat_samples(samples)
            stats = get_learner_stats(policy.learn_on_batch(batch))
            self.tracker.num_steps_trained += batch.count

        stats.update(policy.get_exploration_info())
        return stats
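# Illustrative arithmetic for the real/virtual batch split used in
# `improve_policy` above; the concrete numbers are hypothetical.
def split_batch(train_batch_size: int, real_data_ratio: float) -> tuple:
    env_batch_size = int(train_batch_size * real_data_ratio)
    model_batch_size = train_batch_size - env_batch_size
    return env_batch_size, model_batch_size


# With a batch of 256 and 10% real data, 25 samples come from the environment
# replay and 231 from the virtual (model-generated) replay.
assert split_batch(256, 0.1) == (25, 231)
# With real_data_ratio == 1.0 the virtual buffer is never sampled.
assert split_batch(256, 1.0) == (256, 0)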
class MBPOTorchPolicy(MBPolicyMixin, EnvFnMixin, ModelSamplingMixin, SACTorchPolicy):
    """Model-Based Policy Optimization policy in PyTorch to use with RLlib."""

    # pylint:disable=too-many-ancestors
    virtual_replay: NumpyReplayBuffer
    model_trainer: LightningModelTrainer
    dist_class = WrapStochasticPolicy

    def __init__(self, observation_space, action_space, config):
        super().__init__(observation_space, action_space, config)
        models = self.module.models
        self.loss_model = MaximumLikelihood(models)

        self.build_timers()
        self.model_trainer = LightningModelTrainer(
            models=self.module.models,
            loss_fn=self.loss_model,
            optimizer=self.optimizers["models"],
            replay=self.replay,
            config=self.config,
        )

    def _make_optimizers(self):
        optimizers = super()._make_optimizers()
        config = self.config["optimizer"]
        optimizers["models"] = build_optimizer(self.module.models, config["models"])
        return optimizers

    def build_replay_buffer(self):
        super().build_replay_buffer()
        self.virtual_replay = NumpyReplayBuffer(
            self.observation_space,
            self.action_space,
            self.config["virtual_buffer_size"],
        )
        self.virtual_replay.seed(self.config["seed"])

    def build_timers(self):
        super().build_timers()
        self.timers["augmentation"] = TimerStat()

    @learner_stats
    def learn_on_batch(self, samples: SampleBatch) -> dict:
        self.add_to_buffer(samples)
        self._learn_calls += 1

        info = {}
        warmup = self._learn_calls == 1
        if self._learn_calls % self.config["model_update_interval"] == 0 or warmup:
            with self.timers["model"] as timer:
                losses, model_info = self.train_dynamics_model(warmup=warmup)
                timer.push_units_processed(model_info["model_epochs"])
                info.update(model_info)
            self.set_new_elite(losses)

        with self.timers["augmentation"] as timer:
            count_before = len(self.virtual_replay)
            self.populate_virtual_buffer()
            timer.push_units_processed(len(self.virtual_replay) - count_before)

        with self.timers["policy"] as timer:
            times = self.config["improvement_steps"]
            policy_info = self.update_policy(times=times)
            timer.push_units_processed(times)
            info.update(policy_info)

        info.update(self.timer_stats())
        return info

    def train_dynamics_model(
        self, warmup: bool = False
    ) -> Tuple[List[float], StatDict]:
        return self.model_trainer.optimize(warmup=warmup)

    def populate_virtual_buffer(self):
        # pylint:disable=missing-function-docstring
        num_rollouts = self.config["model_rollouts"]
        real_data_ratio = self.config["real_data_ratio"]
        if not (num_rollouts and real_data_ratio < 1.0):
            return

        real_samples = self.replay.sample(num_rollouts)
        virtual_samples = self.generate_virtual_sample_batch(real_samples)
        self.virtual_replay.add(virtual_samples)

    def update_policy(self, times: int) -> StatDict:
        batch_size = self.config["batch_size"]
        env_batch_size = int(batch_size * self.config["real_data_ratio"])
        model_batch_size = batch_size - env_batch_size

        for _ in range(times):
            samples = []
            if env_batch_size:
                samples += [self.replay.sample(env_batch_size)]
            if model_batch_size:
                samples += [self.virtual_replay.sample(model_batch_size)]
            batch = SampleBatch.concat_samples(samples)
            batch = self.lazy_tensor_dict(batch)
            info = self.improve_policy(batch)
        return info

    def timer_stats(self) -> dict:
        stats = super().timer_stats()
        augmentation_timer = self.timers["augmentation"]
        stats.update(
            augmentation_time_s=round(augmentation_timer.mean, 3),
            augmentation_throughput=round(augmentation_timer.mean_throughput, 3),
        )
        return stats
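# Sketch of the model-update schedule in `learn_on_batch` above: the dynamics
# model is (re)trained on the first call and then on every
# `model_update_interval`-th call. The interval value used here is hypothetical.
def should_update_model(learn_calls: int, model_update_interval: int) -> bool:
    warmup = learn_calls == 1
    return learn_calls % model_update_interval == 0 or warmup


calls_that_update = [
    t for t in range(1, 11) if should_update_model(t, model_update_interval=4)
]
assert calls_that_update == [1, 4, 8]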
class OffPolicyTrainer(Trainer):
    """Generic trainer for off-policy agents."""

    # pylint: disable=attribute-defined-outside-init
    _name = ""
    _default_config = None
    _policy = None

    @override(Trainer)
    def _init(self, config, env_creator):
        self.validate_config(config)
        self.workers = self._make_workers(
            env_creator, self._policy, config, num_workers=0
        )
        self.build_replay_buffer(config)

    @override(Trainer)
    def _train(self):
        start_samples = self.sample_until_learning_starts()

        worker = self.workers.local_worker()
        policy = worker.get_policy()
        stats = {}
        while not self._iteration_done():
            samples = worker.sample()
            self.tracker.num_steps_sampled += samples.count
            for row in samples.rows():
                self.replay.add(row)
            stats.update(policy.get_exploration_info())

            self._before_replay_steps(policy)
            for _ in range(samples.count):
                batch = self.replay.sample(self.config["train_batch_size"])
                stats = get_learner_stats(policy.learn_on_batch(batch))
                self.tracker.num_steps_trained += batch.count

        self.tracker.num_steps_sampled += start_samples
        return self._log_metrics(stats)

    def build_replay_buffer(self, config):
        """Construct replay buffer to hold samples."""
        policy = self.get_policy()
        self.replay = NumpyReplayBuffer(
            policy.observation_space, policy.action_space, config["buffer_size"]
        )
        self.replay.seed(config["seed"])

    def sample_until_learning_starts(self):
        """Sample enough transitions so that `learning_starts` steps are
        collected before the next policy update.
        """
        learning_starts = self.config["learning_starts"]
        worker = self.workers.local_worker()
        sample_count = 0
        while self.tracker.num_steps_sampled + sample_count < learning_starts:
            samples = worker.sample()
            sample_count += samples.count
            for row in samples.rows():
                self.replay.add(row)
        return sample_count

    def _before_replay_steps(self, policy):  # pylint:disable=unused-argument
        pass

    @staticmethod
    def validate_config(config):
        """Assert configuration values are valid."""
        assert config["num_workers"] == 0, "No point in using additional workers."
        assert (
            config["rollout_fragment_length"] >= 1
        ), "At least one sample must be collected."
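# Toy sketch of `sample_until_learning_starts`: keep collecting rollouts until
# at least `learning_starts` environment steps have been gathered. The fixed
# rollout length of 25 steps is hypothetical.
def steps_until_learning_starts(
    num_steps_sampled: int, learning_starts: int, rollout_fragment_length: int = 25
) -> int:
    sample_count = 0
    while num_steps_sampled + sample_count < learning_starts:
        sample_count += rollout_fragment_length
    return sample_count


# Starting from scratch with learning_starts=1000, 40 rollouts of 25 steps
# are collected before the first policy update.
assert steps_until_learning_starts(0, 1000) == 1000
# If enough steps were already sampled, no warm-up rollouts are needed.
assert steps_until_learning_starts(1200, 1000) == 0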
class OffPolicyMixin(ABC):
    """Adds a replay buffer and standard procedures for `learn_on_batch`."""

    replay: NumpyReplayBuffer

    def build_replay_buffer(self):
        """Construct the experience replay buffer.

        Should be called by subclasses on init.
        """
        self.replay = NumpyReplayBuffer(
            self.observation_space, self.action_space, self.config["buffer_size"]
        )
        self.replay.seed(self.config["seed"])
        self.replay.compute_stats = self.config["std_obs"]

    @learner_stats
    def learn_on_batch(self, samples: SampleBatch):
        """Run one logical iteration of training.

        Returns:
            An info dict from this iteration.
        """
        self.add_to_buffer(samples)

        info = {}
        info.update(self.get_exploration_info())
        for _ in range(int(self.config["improvement_steps"])):
            batch = self.replay.sample(self.config["batch_size"])
            batch = self.lazy_tensor_dict(batch)
            info.update(self.improve_policy(batch))

        return info

    def add_to_buffer(self, samples: SampleBatch):
        """Add sample batch to replay buffer."""
        self.replay.add(samples)

    @abstractmethod
    def improve_policy(self, batch: TensorDict) -> dict:
        """Run one step of Policy Improvement."""

    def compute_actions(
        self,
        obs_batch: Union[List[TensorType], TensorType],
        state_batches: Optional[List[TensorType]] = None,
        prev_action_batch: Union[List[TensorType], TensorType] = None,
        prev_reward_batch: Union[List[TensorType], TensorType] = None,
        info_batch: Optional[Dict[str, list]] = None,
        episodes: Optional[List[MultiAgentEpisode]] = None,
        explore: Optional[bool] = None,
        timestep: Optional[int] = None,
        **kwargs,
    ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
        # pylint:disable=too-many-arguments
        obs_batch = self.replay.normalize(obs_batch)
        return super().compute_actions(
            obs_batch,
            state_batches=state_batches,
            prev_action_batch=prev_action_batch,
            prev_reward_batch=prev_reward_batch,
            info_batch=info_batch,
            episodes=episodes,
            explore=explore,
            timestep=timestep,
            **kwargs,
        )

    def compute_log_likelihoods(
        self,
        actions: Union[List[TensorType], TensorType],
        obs_batch: Union[List[TensorType], TensorType],
        state_batches: Optional[List[TensorType]] = None,
        prev_action_batch: Optional[Union[List[TensorType], TensorType]] = None,
        prev_reward_batch: Optional[Union[List[TensorType], TensorType]] = None,
    ) -> TensorType:
        # pylint:disable=too-many-arguments
        obs_batch = self.replay.normalize(obs_batch)
        return super().compute_log_likelihoods(
            actions=actions,
            obs_batch=obs_batch,
            state_batches=state_batches,
            prev_action_batch=prev_action_batch,
            prev_reward_batch=prev_reward_batch,
        )

    def get_weights(self) -> dict:
        state = super().get_weights()
        state["replay"] = self.replay.state_dict()
        return state

    def set_weights(self, weights: dict):
        self.replay.load_state_dict(weights["replay"])
        super().set_weights({k: v for k, v in weights.items() if k != "replay"})

    @staticmethod
    def add_options(policy_cls: type) -> type:
        """Decorator to add default off-policy options used by OffPolicyMixin."""
        return off_policy_options(policy_cls)
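# Toy illustration of the weight (de)serialization convention above: the replay
# buffer's running statistics travel under a dedicated "replay" key and are
# stripped out before the parent `set_weights` sees the dict. All names here
# are stand-ins for the real policy and replay classes.
class _ToyReplay:
    def __init__(self):
        self.stats = {"obs_mean": 0.0, "obs_std": 1.0}

    def state_dict(self):
        return dict(self.stats)

    def load_state_dict(self, state):
        self.stats = dict(state)


class _ToyPolicy:
    def __init__(self):
        self.replay = _ToyReplay()
        self.params = {"weight": 1.0}

    def get_weights(self):
        state = dict(self.params)
        state["replay"] = self.replay.state_dict()
        return state

    def set_weights(self, weights):
        self.replay.load_state_dict(weights["replay"])
        self.params = {k: v for k, v in weights.items() if k != "replay"}


src, dst = _ToyPolicy(), _ToyPolicy()
src.replay.stats["obs_mean"] = 0.5
dst.set_weights(src.get_weights())
assert dst.replay.stats["obs_mean"] == 0.5 and "replay" not in dst.params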