Example #1
    def __init__(
        self, algorithm_config: ESAlgorithmConfig, torch_policy: TorchPolicy,
        shared_noise: SharedNoiseTable,
        normalization_stats: Optional[Dict[str, Tuple[np.ndarray, np.ndarray]]]
    ) -> None:
        super().__init__(algorithm_config)

        # --- training setup ---
        self.model_selection: Optional[ModelSelectionBase] = None
        self.policy: Union[Policy, TorchModel] = torch_policy

        self.shared_noise = shared_noise
        self.normalization_stats = normalization_stats

        # setup the optimizer, now that the policy is available
        self.optimizer = Factory(Optimizer).instantiate(
            algorithm_config.optimizer)
        self.optimizer.setup(self.policy)

        # prepare statistics collection
        self.eval_stats = LogStatsAggregator(LogStatsLevel.EPOCH,
                                             get_stats_logger("eval"))
        self.train_stats = LogStatsAggregator(LogStatsLevel.EPOCH,
                                              get_stats_logger("train"))
        # injection of ES-specific events
        self.es_events = self.train_stats.create_event_topic(ESEvents)
Example #2
    def _launch_workers(self, env: ConfigType, wrappers: CollectionOfConfigType, agent: ConfigType) \
            -> Iterable[Process]:
        """Configure the workers according to the rollout config and launch them."""
        # Split total episode count across workers
        episodes_per_process = [0] * self.n_processes
        for i in range(self.n_episodes):
            episodes_per_process[i % self.n_processes] += 1

        # Configure and launch the processes
        self.reporting_queue = Queue()
        workers = []
        for n_process_episodes in episodes_per_process:
            # Episodes are assigned round-robin above, so zero counts can only occur at
            # the tail of the list -- once we hit one, no further process gets episodes.
            if n_process_episodes == 0:
                break

            p = Process(
                target=ParallelRolloutWorker.run,
                args=(env, wrappers, agent,
                      n_process_episodes, self.max_episode_steps,
                      self.record_trajectory, self.input_dir, self.reporting_queue,
                      self.maze_seeding.generate_env_instance_seed(),
                      self.maze_seeding.generate_agent_instance_seed()),
                daemon=True
            )
            p.start()
            workers.append(p)

        # Perform writer registration -- after the forks so that it is not carried over to child processes
        if self.record_event_logs:
            LogEventsWriterRegistry.register_writer(LogEventsWriterTSV(log_dir="./event_logs"))
        register_log_stats_writer(LogStatsWriterConsole())
        self.epoch_stats_aggregator = LogStatsAggregator(LogStatsLevel.EPOCH)
        self.epoch_stats_aggregator.register_consumer(get_stats_logger("rollout_stats"))

        return workers
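
To make the episode split above concrete, here is a tiny standalone sketch with made-up numbers (10 episodes over 4 processes); it shows why zero counts can only appear at the tail of the list, so the loop above may safely break on the first zero instead of continuing.

# Standalone sketch of the round-robin split used above (illustrative numbers only).
n_episodes, n_processes = 10, 4

episodes_per_process = [0] * n_processes
for i in range(n_episodes):
    episodes_per_process[i % n_processes] += 1

print(episodes_per_process)  # [3, 3, 2, 2] -- zeros, if any, would sit at the end
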
Example #3
    def __init__(self, env: MazeEnv, logging_prefix: Optional[str] = None):
        """Avoid calling this constructor directly, use :method:`wrap` instead."""
        # BaseEnv is a subset of gym.Env
        super().__init__(env)

        # initialize step aggregator
        self.epoch_stats = LogStatsAggregator(LogStatsLevel.EPOCH)
        self.episode_stats = LogStatsAggregator(LogStatsLevel.EPISODE, self.epoch_stats)
        self.step_stats = LogStatsAggregator(LogStatsLevel.STEP, self.episode_stats)

        self.stats_map = {
            LogStatsLevel.EPOCH: self.epoch_stats,
            LogStatsLevel.EPISODE: self.episode_stats,
            LogStatsLevel.STEP: self.step_stats
        }

        if logging_prefix is not None:
            self.epoch_stats.register_consumer(get_stats_logger(logging_prefix))

        self.last_env_time: Optional[int] = None
        self.reward_events = EventCollection()
        self.episode_event_log: Optional[EpisodeEventLog] = None

        self.step_stats_renderer = EventStatsRenderer()

        # register a post-step callback so that stats are recorded even when a wrapper
        # in the middle of the stack steps the environment (as done e.g. during step-skipping)
        if hasattr(env, "context") and isinstance(env.context, EnvironmentContext):
            env.context.register_post_step(self._record_stats_if_ready)
Example #4
    def __init__(self,
                 env_factory: Callable[[], Union[StructuredEnv, StructuredEnvSpacesMixin, LogStatsEnv]],
                 worker_policy: TorchPolicy,
                 n_rollout_steps: int,
                 n_workers: int,
                 batch_size: int,
                 rollouts_per_iteration: int,
                 split_rollouts_into_transitions: bool,
                 env_instance_seeds: List[int],
                 replay_buffer: BaseReplayBuffer):

        self.env_factory = env_factory
        self._worker_policy = worker_policy
        self.n_rollout_steps = n_rollout_steps
        self.n_workers = n_workers
        self.batch_size = batch_size
        self.replay_buffer = replay_buffer

        self.env_instance_seeds = env_instance_seeds

        self.epoch_stats = LogStatsAggregator(LogStatsLevel.EPOCH)
        self.epoch_stats.register_consumer(get_stats_logger('train'))

        self.rollouts_per_iteration = rollouts_per_iteration
        self.split_rollouts_into_transitions = split_rollouts_into_transitions

        self._init_workers()
Example #5
    def __init__(self,
                 n_envs: int,
                 action_spaces_dict: Dict[StepKeyType, gym.spaces.Space],
                 observation_spaces_dict: Dict[StepKeyType, gym.spaces.Space],
                 agent_counts_dict: Dict[StepKeyType, int],
                 logging_prefix: Optional[str] = None):
        super().__init__(n_envs)

        # Spaces
        self._action_spaces_dict = action_spaces_dict
        self._observation_spaces_dict = observation_spaces_dict
        self._agent_counts_dict = agent_counts_dict

        # Aggregate episode statistics from individual envs
        self.epoch_stats = LogStatsAggregator(LogStatsLevel.EPOCH)

        # register a logger for the epoch statistics if desired
        if logging_prefix is not None:
            self.epoch_stats.register_consumer(
                get_stats_logger(logging_prefix))

        # Keep track of current actor IDs, actor dones, and env times (should be updated in step and reset methods).
        self._actor_ids = None
        self._actor_dones = None
        self._env_times = None
Example #6
    def __init__(self,
                 env_factory: Callable[[], Union[StructuredEnv, StructuredEnvSpacesMixin, LogStatsEnv]],
                 policy: TorchPolicy,
                 n_rollout_steps: int,
                 n_actors: int,
                 batch_size: int):
        self.env_factory = env_factory
        self.policy = policy
        self.n_rollout_steps = n_rollout_steps
        self.n_actors = n_actors
        self.batch_size = batch_size

        self.epoch_stats = LogStatsAggregator(LogStatsLevel.EPOCH)
        self.epoch_stats.register_consumer(get_stats_logger('train'))
Example #7
    def init_replay_buffer(replay_buffer: BaseReplayBuffer, initial_sampling_policy: Union[DictConfig, Policy],
                           initial_buffer_size: int, replay_buffer_seed: int,
                           split_rollouts_into_transitions: bool, n_rollout_steps: int,
                           env_factory: Callable[[], MazeEnv]) -> None:
        """Fill the buffer with initial_buffer_size rollouts by rolling out the initial_sampling_policy.

        :param replay_buffer: The replay buffer to use.
        :param initial_sampling_policy: The initial sampling policy used to fill the buffer to the initial fill state.
        :param initial_buffer_size: The initial size of the replay buffer filled by sampling from the initial sampling
            policy.
        :param replay_buffer_seed: A seed for initializing and sampling from the replay buffer.
        :param split_rollouts_into_transitions: Specify whether to split rollouts into individual transitions.
        :param n_rollout_steps: Number of rollout steps to record in one rollout.
        :param env_factory: Factory function for envs to run rollouts on.
        """

        # Create the log stats aggregator for collecting KPIs while initializing the replay buffer
        epoch_stats = LogStatsAggregator(LogStatsLevel.EPOCH)
        replay_stats_logger = get_stats_logger('init_replay_buffer')
        epoch_stats.register_consumer(replay_stats_logger)

        dummy_env = env_factory()
        dummy_env.seed(replay_buffer_seed)
        sampling_policy: Policy = \
            Factory(Policy).instantiate(initial_sampling_policy, action_spaces_dict=dummy_env.action_spaces_dict)
        sampling_policy.seed(replay_buffer_seed)
        rollout_generator = RolloutGenerator(env=dummy_env,
                                             record_next_observations=True,
                                             record_episode_stats=True)

        print(f'******* Starting to fill the replay buffer with {initial_buffer_size} transitions *******')
        while len(replay_buffer) < initial_buffer_size:
            trajectory = rollout_generator.rollout(policy=sampling_policy, n_steps=n_rollout_steps)

            if split_rollouts_into_transitions:
                replay_buffer.add_rollout(trajectory)
            else:
                replay_buffer.add_transition(trajectory)

            # collect episode statistics
            for step_record in trajectory.step_records:
                if step_record.episode_stats is not None:
                    epoch_stats.receive(step_record.episode_stats)

        # Print the KPIs from initializing the replay buffer
        epoch_stats.reduce()
        # Remove the consumer again from the aggregator
        epoch_stats.remove_consumer(replay_stats_logger)
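
The statistics handling in this example boils down to a small reusable pattern: attach a consumer, feed stats, reduce once, detach. Below is a minimal sketch of just that pattern, assuming only the calls that appear in the example above; `episode_stats_iterable` is a hypothetical stand-in for the per-episode stats collected during the rollouts.

# Sketch of the transient-aggregator pattern from the example above.
# `episode_stats_iterable` is a hypothetical stand-in for collected episode stats.
epoch_stats = LogStatsAggregator(LogStatsLevel.EPOCH)
replay_stats_logger = get_stats_logger('init_replay_buffer')
epoch_stats.register_consumer(replay_stats_logger)

for episode_stats in episode_stats_iterable:
    epoch_stats.receive(episode_stats)

epoch_stats.reduce()                              # aggregate and hand over to the consumer
epoch_stats.remove_consumer(replay_stats_logger)  # detach so later stats are unaffected
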
Example #8
    def __init__(self,
                 loss: BCLoss,
                 model_selection: Optional[ModelSelectionBase],
                 data_loader: DataLoader,
                 logging_prefix: Optional[str] = "eval"):
        self.loss = loss
        self.data_loader = data_loader
        self.model_selection = model_selection

        self.env = None
        if logging_prefix:
            self.eval_stats = LogStatsAggregator(
                LogStatsLevel.EPOCH, get_stats_logger(logging_prefix))
        else:
            self.eval_stats = LogStatsAggregator(LogStatsLevel.EPOCH)
        self.eval_events = self.eval_stats.create_event_topic(ImitationEvents)
Example #9
    def init_logging(self, trainer_config: Dict[str, Any]) -> None:
        """Initialize logging.

        This needs to be done here because on_train_result is not called in the main process, but in a worker process.

        Relies on the following:
          - There should be only one Callbacks object per worker process
          - on_train_result should always be called in the same worker (if this is not the case,
            this system should still handle it, but it might mess things up)
        """
        assert self.epoch_stats is None, "Init logging should be called only once"

        # Local epoch stats -- stats from all envs will be collected here together
        self.epoch_stats = LogStatsAggregator(LogStatsLevel.EPOCH)
        self.epoch_stats.register_consumer(get_stats_logger("train"))

        # Initialize Tensorboard and console writers
        writer = LogStatsWriterTensorboard(log_dir='.',
                                           tensorboard_render_figure=True)
        register_log_stats_writer(writer)
        register_log_stats_writer(LogStatsWriterConsole())

        summary_writer = writer.summary_writer

        # Add config to tensorboard
        yaml_config = pprint.pformat(trainer_config)
        # prepare config text for tensorboard
        yaml_config = yaml_config.replace("\n", "</br>")
        yaml_config = yaml_config.replace(" ", "&nbsp;")
        summary_writer.add_text("job_config", yaml_config)

        # Load the figures from the given files and add them to tensorboard.
        network_files = filter(lambda x: x.endswith('.figure.pkl'),
                               os.listdir('.'))
        for network_path in network_files:
            network_name = network_path.split('/')[-1].replace('.figure.pkl', '')
            # load the pickled figure and close the file handle again
            with open(network_path, 'rb') as figure_file:
                fig = pickle.load(figure_file)
            summary_writer.add_figure(f'{network_name}', fig, close=True)
            os.remove(network_path)
Example #10
class BCTrainer(Trainer):
    """Trainer for behavioral cloning learning.

    Runs training on top of provided trajectory data and rolls out the policy using the provided evaluator.

    In structured (multi-step) envs, all policies are trained simultaneously based on the substep actions
    and observations present in the trajectory data.
    """

    data_loader: DataLoader
    """Data loader for loading trajectory data."""

    policy: TorchPolicy
    """Structured policy to train."""

    optimizer: Optimizer
    """Optimizer to use"""

    loss: BCLoss
    """Class providing the training loss function."""

    train_stats: LogStatsAggregator = LogStatsAggregator(
        LogStatsLevel.EPOCH, get_stats_logger("train"))
    """Training statistics"""

    imitation_events: ImitationEvents = train_stats.create_event_topic(
        ImitationEvents)
    """Imitation-specific training events"""
    def __init__(self, algorithm_config: BCAlgorithmConfig,
                 data_loader: DataLoader, policy: TorchPolicy,
                 optimizer: Optimizer, loss: BCLoss):
        super().__init__(algorithm_config)

        self.data_loader = data_loader
        self.policy = policy
        self.optimizer = optimizer
        self.loss = loss

    @override(Trainer)
    def train(self,
              evaluator: Evaluator,
              n_epochs: Optional[int] = None,
              eval_every_k_iterations: Optional[int] = None) -> None:
        """
        Run training.
        :param evaluator: Evaluator to use for evaluation rollouts
        :param n_epochs: How many epochs to train for
        :param eval_every_k_iterations: Number of iterations after which to run evaluation (in addition to evaluations
            at the end of each epoch, which are run automatically). If set to None, evaluations will run on epoch end only.
        """

        if n_epochs is None:
            n_epochs = self.algorithm_config.n_epochs
        if eval_every_k_iterations is None:
            eval_every_k_iterations = self.algorithm_config.eval_every_k_iterations

        for epoch in range(n_epochs):
            print(f"\n********** Epoch {epoch + 1} started **********")
            evaluator.evaluate(self.policy)
            increment_log_step()

            for iteration, data in enumerate(self.data_loader, 0):
                observations, actions, actor_ids = data
                self._run_iteration(observations=observations,
                                    actions=actions,
                                    actor_ids=actor_ids)

                # Evaluate after each k iterations if set
                if eval_every_k_iterations is not None and \
                        iteration % eval_every_k_iterations == (eval_every_k_iterations - 1):
                    print(
                        f"\n********** Epoch {epoch + 1}: Iteration {iteration + 1} **********"
                    )
                    evaluator.evaluate(self.policy)
                    increment_log_step()

        print(f"\n********** Final evaluation **********")
        evaluator.evaluate(self.policy)
        increment_log_step()

    def load_state_dict(self, state_dict: Dict) -> None:
        """Set the model and optimizer state.
        :param state_dict: The state dict.
        """
        self.policy.load_state_dict(state_dict)

    @override(Trainer)
    def state_dict(self):
        """implementation of :class:`~maze.train.trainers.common.trainer.Trainer`
        """
        return self.policy.state_dict()

    @override(Trainer)
    def load_state(self, file_path: Union[str, BinaryIO]) -> None:
        """implementation of :class:`~maze.train.trainers.common.trainer.Trainer`
        """
        state_dict = torch.load(file_path,
                                map_location=torch.device(self.policy.device))
        self.load_state_dict(state_dict)

    def _run_iteration(self, observations: List[Union[ObservationType,
                                                      TorchObservationType]],
                       actions: List[Union[ActionType, TorchActionType]],
                       actor_ids: List[ActorID]) -> None:
        """Run a single training iterations of the behavioural cloning.

        :param observations: A list (w.r.t. the substeps/agents) of batched observations.
        :param actions: A list (w.r.t. the substeps/agents) of batched actions.
        :param actor_ids: A list (w.r.t. the substeps/agents) of the corresponding batched actor_ids.
        """
        self.policy.train()
        self.optimizer.zero_grad()

        # The actor IDs within a given batch should all be the same, so we can debatch them.
        actor_ids = debatch_actor_ids(actor_ids)

        # Convert only the actions to torch; observations are converted inside policy.compute_substep_policy_output
        actions = convert_to_torch(actions,
                                   device=self.policy.device,
                                   cast=None,
                                   in_place=True)
        total_loss = self.loss.calculate_loss(policy=self.policy,
                                              observations=observations,
                                              actions=actions,
                                              actor_ids=actor_ids,
                                              events=self.imitation_events)
        total_loss.backward()
        self.optimizer.step()

        # Report additional policy-related stats
        for actor_id in actor_ids:
            l2_norm = sum([
                param.norm()
                for param in self.policy.network_for(actor_id).parameters()
            ])
            grad_norm = compute_gradient_norm(
                self.policy.network_for(actor_id).parameters())

            self.imitation_events.policy_l2_norm(step_id=actor_id.step_key,
                                                 agent_id=actor_id.agent_id,
                                                 value=l2_norm.item())
            self.imitation_events.policy_grad_norm(step_id=actor_id.step_key,
                                                   agent_id=actor_id.agent_id,
                                                   value=grad_norm)
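
Pulling the logging pieces of this last example together, the following is a hedged sketch of how the event-topic statistics fired in _run_iteration can end up on the console. It reuses only calls that appear verbatim in the examples above (register_log_stats_writer, LogStatsWriterConsole, create_event_topic, increment_log_step); the event values are made up, and the assumption is that incrementing the log step is what flushes the epoch-level aggregation, as the training loop in this example suggests.

# Sketch only -- wiring as in the examples above, with illustrative values.
register_log_stats_writer(LogStatsWriterConsole())  # console writer, as in Example #9

train_stats = LogStatsAggregator(LogStatsLevel.EPOCH, get_stats_logger("train"))
imitation_events = train_stats.create_event_topic(ImitationEvents)

# Fire an event the same way BCTrainer._run_iteration does (values are illustrative).
imitation_events.policy_l2_norm(step_id=0, agent_id=0, value=1.23)

# The training loop pairs each evaluation with increment_log_step(); the assumption here
# is that this step increment triggers the epoch-level flush to the registered writers.
increment_log_step()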