Example #1
    def init_replay_buffer(replay_buffer: BaseReplayBuffer, initial_sampling_policy: Union[DictConfig, Policy],
                           initial_buffer_size: int, replay_buffer_seed: int,
                           split_rollouts_into_transitions: bool, n_rollout_steps: int,
                           env_factory: Callable[[], MazeEnv]) -> None:
        """Fill the buffer with initial_buffer_size rollouts by rolling out the initial_sampling_policy.

        :param replay_buffer: The replay buffer to use.
        :param initial_sampling_policy: The initial sampling policy used to fill the buffer to the initial fill state.
        :param initial_buffer_size: The initial size of the replay buffer filled by sampling from the initial sampling
            policy.
        :param replay_buffer_seed: A seed for initializing and sampling from the replay buffer.
        :param split_rollouts_into_transitions: Specify whether to split rollouts into individual transitions.
        :param n_rollout_steps: Number of rollout steps to record in one rollout.
        :param env_factory: Factory function for envs to run rollouts on.
        """

        # Create the log stats aggregator for collecting KPIs of the replay buffer initialization
        epoch_stats = LogStatsAggregator(LogStatsLevel.EPOCH)
        replay_stats_logger = get_stats_logger('init_replay_buffer')
        epoch_stats.register_consumer(replay_stats_logger)

        dummy_env = env_factory()
        dummy_env.seed(replay_buffer_seed)
        sampling_policy: Policy = \
            Factory(Policy).instantiate(initial_sampling_policy, action_spaces_dict=dummy_env.action_spaces_dict)
        sampling_policy.seed(replay_buffer_seed)
        rollout_generator = RolloutGenerator(env=dummy_env,
                                             record_next_observations=True,
                                             record_episode_stats=True)

        print(f'******* Starting to fill the replay buffer with {initial_buffer_size} transitions *******')
        while len(replay_buffer) < initial_buffer_size:
            trajectory = rollout_generator.rollout(policy=sampling_policy, n_steps=n_rollout_steps)

            # when splitting, each step of the trajectory is stored as an individual transition;
            # otherwise the whole trajectory is stored as a single buffer entry
            if split_rollouts_into_transitions:
                replay_buffer.add_rollout(trajectory)
            else:
                replay_buffer.add_transition(trajectory)

            # collect episode statistics
            for step_record in trajectory.step_records:
                if step_record.episode_stats is not None:
                    epoch_stats.receive(step_record.episode_stats)

        # Emit the KPIs collected while initializing the replay buffer
        epoch_stats.reduce()
        # Remove the consumer again from the aggregator
        epoch_stats.remove_consumer(replay_stats_logger)
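
A minimal invocation sketch for the helper above. The concrete buffer class and configs here are assumptions for illustration only: `UniformReplayBuffer` standing in as the `BaseReplayBuffer` implementation, `random_policy_config` as a hydra `DictConfig` describing the initial sampling policy, and `make_env` as a factory returning a `MazeEnv`:

# usage sketch -- the buffer signature, make_env and random_policy_config are illustrative assumptions
buffer = UniformReplayBuffer(buffer_size=100_000, seed=1234)
init_replay_buffer(
    replay_buffer=buffer,
    initial_sampling_policy=random_policy_config,
    initial_buffer_size=5_000,
    replay_buffer_seed=1234,
    split_rollouts_into_transitions=True,
    n_rollout_steps=200,
    env_factory=make_env,
)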
Example #2
class ESTrainer(Trainer):
    """Trainer class for OpenAI Evolution Strategies.

    :param algorithm_config: Algorithm parameters.
    :param torch_policy: Multi-step policy encapsulating the policy networks.
    :param shared_noise: The noise table, with the same content for every worker and the master.
    :param normalization_stats: Normalization statistics as calculated by the NormalizeObservationWrapper.
    """
    def __init__(
        self, algorithm_config: ESAlgorithmConfig, torch_policy: TorchPolicy,
        shared_noise: SharedNoiseTable,
        normalization_stats: Optional[Dict[str, Tuple[np.ndarray, np.ndarray]]]
    ) -> None:
        super().__init__(algorithm_config)

        # --- training setup ---
        self.model_selection: Optional[ModelSelectionBase] = None
        self.policy: Union[Policy, TorchModel] = torch_policy

        self.shared_noise = shared_noise
        self.normalization_stats = normalization_stats

        # setup the optimizer, now that the policy is available
        self.optimizer = Factory(Optimizer).instantiate(
            algorithm_config.optimizer)
        self.optimizer.setup(self.policy)

        # prepare statistics collection
        self.eval_stats = LogStatsAggregator(LogStatsLevel.EPOCH,
                                             get_stats_logger("eval"))
        self.train_stats = LogStatsAggregator(LogStatsLevel.EPOCH,
                                              get_stats_logger("train"))
        # injection of ES-specific events
        self.es_events = self.train_stats.create_event_topic(ESEvents)

    @override(Trainer)
    def train(self,
              distributed_rollouts: ESDistributedRollouts,
              n_epochs: Optional[int] = None,
              model_selection: Optional[ModelSelectionBase] = None) -> None:
        """
        Run the ES training loop.
        :param distributed_rollouts: The distribution interface for experience collection.
        :param n_epochs: Number of epochs to train.
        :param model_selection: Optional model selection class, receives model evaluation results.
        """

        n_epochs = self.algorithm_config.n_epochs if n_epochs is None else n_epochs
        self.model_selection = model_selection

        for epoch in itertools.count():
            # check if we reached the max number of epochs
            if n_epochs and epoch == n_epochs:
                break

            print(f'********** Iteration {epoch} **********')

            step_start_time = time.time()

            # do the actual update step (disable autograd, as we calculate the gradient from the rollout returns)
            with torch.no_grad():
                self._update(distributed_rollouts)

            step_end_time = time.time()

            # log the step duration
            self.es_events.real_time(step_end_time - step_start_time)

            # update the epoch count
            increment_log_step()

    def load_state_dict(self, state_dict: Dict) -> None:
        """Set the model and optimizer state.
        :param state_dict: The state dict.
        """
        self.policy.load_state_dict(state_dict)

    @override(Trainer)
    def state_dict(self) -> Dict:
        """Implementation of :class:`~maze.train.trainers.common.trainer.Trainer`
        """
        return self.policy.state_dict()

    @override(Trainer)
    def load_state(self, file_path: Union[str, BinaryIO]) -> None:
        """implementation of :class:`~maze.train.trainers.common.trainer.Trainer`
        """
        state_dict = torch.load(file_path,
                                map_location=torch.device(self.policy.device))
        self.load_state_dict(state_dict)

    def _update(self, distributed_rollouts: ESDistributedRollouts) -> None:
        """Perform a single ES update step based on rollouts drawn from the distribution interface."""
        # Pop off results for the current task
        n_train_episodes, n_timesteps_popped = 0, 0

        # aggregate all collected training rollouts for this episode
        epoch_results = ESRolloutResult(is_eval=False)

        # obtain a generator from the distribution interface
        rollouts_generator = distributed_rollouts.generate_rollouts(
            policy=self.policy,
            max_steps=self.algorithm_config.max_steps,
            noise_stddev=self.algorithm_config.noise_stddev,
            normalization_stats=self.normalization_stats)

        # collect eval and training rollouts
        for result in rollouts_generator:
            if result.is_eval:
                # This was an eval job
                for e in result.episode_stats:
                    self.eval_stats.receive(e)
                continue

            # we received training experience from perturbed policy networks
            epoch_results.noise_indices.extend(result.noise_indices)
            epoch_results.episode_stats.extend(result.episode_stats)

            # update the training statistics
            for e in result.episode_stats:
                self.train_stats.receive(e)

                n_train_episodes += 1
                n_timesteps_popped += e[(BaseEnvEvents.reward, "count", None)]

            # continue until we collected enough episodes and timesteps
            if (n_train_episodes >= self.algorithm_config.n_rollouts_per_update
                    and n_timesteps_popped >=
                    self.algorithm_config.n_timesteps_per_update):
                break

        # notify the model selection of the evaluation results
        eval_stats = self.eval_stats.reduce()
        if self.model_selection and len(eval_stats):
            reward = eval_stats[(BaseEnvEvents.reward, "mean", None)]
            self.model_selection.update(reward)

        # prepare returns, reshape the positive/negative antithetic estimation as (rollouts, 2)
        returns_n2 = np.array([
            e[(BaseEnvEvents.reward, "sum", None)]
            for e in epoch_results.episode_stats
        ]).reshape(-1, 2)

        # improve robustness: weight by rank, not by reward
        proc_returns_n2 = self._compute_centered_ranks(returns_n2)

        # compute the gradient
        g = self._batched_weighted_sum(
            proc_returns_n2[:, 0] - proc_returns_n2[:, 1],
            (self.shared_noise.get(idx, self.policy.num_params)
             for idx in epoch_results.noise_indices),
            batch_size=500)

        g /= n_train_episodes / 2.0

        # apply the weight update
        theta = get_flat_parameters(self.policy)
        update_ratio = self.optimizer.update(-g +
                                             self.algorithm_config.l2_penalty *
                                             theta.numpy())

        # statistics logging
        self.es_events.update_ratio(update_ratio)

        for i in self.policy.state_dict().keys():
            self.es_events.policy_grad_norm(policy_id=i,
                                            value=np.square(g).sum()**0.5)
            self.es_events.policy_norm(policy_id=i,
                                       value=np.square(theta).sum()**0.5)

    @classmethod
    def _iter_groups(cls, items: Iterable,
                     group_size: int) -> Generator[Tuple, None, None]:
        """Yield the given items in groups (tuples) of at most `group_size` elements."""
        assert group_size >= 1
        group = []
        for x in items:
            group.append(x)
            if len(group) == group_size:
                yield tuple(group)
                del group[:]
        if group:
            yield tuple(group)

    @classmethod
    def _batched_weighted_sum(cls, weights: Iterable[float],
                              vectors: Iterable[np.ndarray],
                              batch_size: int) -> np.ndarray:
        """calculate a weighted sum of the given vectors, in steps of at most `batch_size` vectors"""
        # start with float, at the first operation numpy broadcasting takes care of the correct shape
        total: Union[np.array, float] = 0.

        for batch_weights, batch_vectors in zip(
                cls._iter_groups(weights, batch_size),
                cls._iter_groups(vectors, batch_size)):
            assert len(batch_weights) == len(batch_vectors) <= batch_size
            total += np.dot(np.asarray(batch_weights, dtype=np.float32),
                            np.asarray(batch_vectors, dtype=np.float32))

        return total

    @classmethod
    def _compute_ranks(cls, x: np.ndarray) -> np.ndarray:
        """
        Returns ranks in [0, len(x))
        Note: This is different from scipy.stats.rankdata, which returns ranks in [1, len(x)].
        """
        assert x.ndim == 1
        ranks = np.empty(len(x), dtype=int)
        ranks[x.argsort()] = np.arange(len(x))
        return ranks

    @classmethod
    def _compute_centered_ranks(cls, x: np.ndarray) -> np.ndarray:
        """Transform the entries of x to centered ranks in the range [-0.5, 0.5]."""
        y = cls._compute_ranks(x.ravel()).reshape(x.shape).astype(np.float32)
        y /= (x.size - 1)
        y -= .5
        return y
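
The rank transform used in `_update` is the standard ES robustness trick: each return contributes only its rank, so a single outlier reward cannot dominate the gradient estimate. A self-contained numpy sketch mirroring `_compute_ranks`/`_compute_centered_ranks` on toy antithetic returns:

import numpy as np

def centered_ranks(x: np.ndarray) -> np.ndarray:
    # rank all entries in [0, x.size), then rescale linearly to [-0.5, 0.5]
    ranks = np.empty(x.size, dtype=int)
    ranks[x.ravel().argsort()] = np.arange(x.size)
    return ranks.reshape(x.shape).astype(np.float32) / (x.size - 1) - 0.5

# three perturbations with antithetic (positive/negative) returns, shape (rollouts, 2)
returns_n2 = np.array([[1.0, 2.0], [3.0, 1000.0], [0.5, 4.0]])
print(centered_ranks(returns_n2))
# [[-0.3 -0.1]
#  [ 0.1  0.5]
#  [-0.5  0.3]]
# the 1000.0 outlier is capped at rank weight 0.5, keeping the update bounded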
Example #3
class ParallelRolloutRunner(RolloutRunner):
    """Runs rollout in multiple processes in parallel.

    Both agent and environment are run in multiple instances across multiple processes. While this greatly speeds
    up the rollout, the memory consumption might be high for large environments and agents.

    Trajectory recording, event logging, as well as stats logging are supported. Trajectory logging happens
    in the child processes. Event logs and stats are shipped back to the main process so that they can be
    handled together there. This allows monitoring of progress and calculation of summary stats across
    all the processes.

    (Note that the relevant wrappers need to be present in the config for the trajectory/event/stats logging to work.
    Data are logged into the working directory managed by hydra.)

    In case of early rollout termination using a keyboard interrupt, data for all episodes completed up to that
    point will be preserved (i.e., written out). A graceful shutdown will be attempted, including calculation of
    statistics across the episodes completed before the rollout was terminated.

    :param n_episodes: Count of episodes to run.
    :param max_episode_steps: Count of steps to run in each episode (if the environment returns done, the
                                episode will be finished earlier).
    :param n_processes: Count of processes to spread the rollout across.
    :param record_trajectory: Whether to record trajectory data.
    :param record_event_logs: Whether to record event logs.
    """

    def __init__(self,
                 n_episodes: int,
                 max_episode_steps: int,
                 n_processes: int,
                 record_trajectory: bool,
                 record_event_logs: bool):
        super().__init__(n_episodes, max_episode_steps, record_trajectory, record_event_logs)
        self.n_processes = n_processes
        self.epoch_stats_aggregator = None
        self.reporting_queue = None

    @override(RolloutRunner)
    def run_with(self, env: ConfigType, wrappers: CollectionOfConfigType, agent: ConfigType):
        """Run the parallel rollout in multiple worker processes."""
        workers = self._launch_workers(env, wrappers, agent)
        try:
            self._monitor_rollout(workers)
        except KeyboardInterrupt:
            self._attempt_graceful_exit(workers)

    def _launch_workers(self, env: ConfigType, wrappers: CollectionOfConfigType, agent: ConfigType) \
            -> Iterable[Process]:
        """Configure the workers according to the rollout config and launch them."""
        # Split total episode count across workers
        episodes_per_process = [0] * self.n_processes
        for i in range(self.n_episodes):
            episodes_per_process[i % self.n_processes] += 1

        # Configure and launch the processes
        self.reporting_queue = Queue()
        workers = []
        for n_process_episodes in episodes_per_process:
            # episodes are distributed round-robin, so once a process is assigned zero
            # episodes, all remaining processes are as well
            if n_process_episodes == 0:
                break

            p = Process(
                target=ParallelRolloutWorker.run,
                args=(env, wrappers, agent,
                      n_process_episodes, self.max_episode_steps,
                      self.record_trajectory, self.input_dir, self.reporting_queue,
                      self.maze_seeding.generate_env_instance_seed(),
                      self.maze_seeding.generate_agent_instance_seed()),
                daemon=True
            )
            p.start()
            workers.append(p)

        # Perform writer registration -- after the forks so that it is not carried over to child processes
        if self.record_event_logs:
            LogEventsWriterRegistry.register_writer(LogEventsWriterTSV(log_dir="./event_logs"))
        register_log_stats_writer(LogStatsWriterConsole())
        self.epoch_stats_aggregator = LogStatsAggregator(LogStatsLevel.EPOCH)
        self.epoch_stats_aggregator.register_consumer(get_stats_logger("rollout_stats"))

        return workers

    def _monitor_rollout(self, workers: Iterable[Process]) -> None:
        """Collect the stats and event logs from the rollout, print progress, and join the workers when done."""

        for _ in tqdm(range(self.n_episodes), desc="Episodes done", unit=" episodes"):
            report = self.reporting_queue.get()
            if isinstance(report, ExceptionReport):
                for p in workers:
                    p.terminate()
                raise RuntimeError("A worker encountered the following error:\n"
                                   + report.traceback) from report.exception

            episode_stats, episode_event_log = report
            if episode_stats is not None:
                self.epoch_stats_aggregator.receive(episode_stats)
            if episode_event_log is not None:
                LogEventsWriterRegistry.record_event_logs(episode_event_log)

        for w in workers:
            w.join()

        if len(self.epoch_stats_aggregator.input) != 0:
            self.epoch_stats_aggregator.reduce()

    def _attempt_graceful_exit(self, workers: Iterable[Process]) -> None:
        """Print statistics collected so far and exit gracefully."""

        print("\n\nShut down requested, exiting gracefully...\n")

        for w in workers:
            w.terminate()

        if len(self.epoch_stats_aggregator.input) != 0:
            print("Stats from the completed part of rollout:\n")
            self.epoch_stats_aggregator.reduce()

        print("\nRollout done (terminated prematurely).")