Example #1
    def __init__(self,
                 env_factory: Callable[[], Union[StructuredEnv, StructuredEnvSpacesMixin, LogStatsEnv]],
                 worker_policy: TorchPolicy,
                 n_rollout_steps: int,
                 n_workers: int,
                 batch_size: int,
                 rollouts_per_iteration: int,
                 split_rollouts_into_transitions: bool,
                 env_instance_seeds: List[int],
                 replay_buffer: BaseReplayBuffer):

        self.env_factory = env_factory
        self._worker_policy = worker_policy
        self.n_rollout_steps = n_rollout_steps
        self.n_workers = n_workers
        self.batch_size = batch_size
        self.replay_buffer = replay_buffer

        self.env_instance_seeds = env_instance_seeds

        self.epoch_stats = LogStatsAggregator(LogStatsLevel.EPOCH)
        self.epoch_stats.register_consumer(get_stats_logger('train'))

        self.rollouts_per_iteration = rollouts_per_iteration
        self.split_rollouts_into_transitions = split_rollouts_into_transitions

        self._init_workers()
Example #2
    def __init__(
        self, algorithm_config: ESAlgorithmConfig, torch_policy: TorchPolicy,
        shared_noise: SharedNoiseTable,
        normalization_stats: Optional[Dict[str, Tuple[np.ndarray, np.ndarray]]]
    ) -> None:
        super().__init__(algorithm_config)

        # --- training setup ---
        self.model_selection: Optional[ModelSelectionBase] = None
        self.policy: Union[Policy, TorchModel] = torch_policy

        self.shared_noise = shared_noise
        self.normalization_stats = normalization_stats

        # setup the optimizer, now that the policy is available
        self.optimizer = Factory(Optimizer).instantiate(
            algorithm_config.optimizer)
        self.optimizer.setup(self.policy)

        # prepare statistics collection
        self.eval_stats = LogStatsAggregator(LogStatsLevel.EPOCH,
                                             get_stats_logger("eval"))
        self.train_stats = LogStatsAggregator(LogStatsLevel.EPOCH,
                                              get_stats_logger("train"))
        # injection of ES-specific events
        self.es_events = self.train_stats.create_event_topic(ESEvents)
Example #3
    def __init__(self,
                 n_envs: int,
                 action_spaces_dict: Dict[StepKeyType, gym.spaces.Space],
                 observation_spaces_dict: Dict[StepKeyType, gym.spaces.Space],
                 agent_counts_dict: Dict[StepKeyType, int],
                 logging_prefix: Optional[str] = None):
        super().__init__(n_envs)

        # Spaces
        self._action_spaces_dict = action_spaces_dict
        self._observation_spaces_dict = observation_spaces_dict
        self._agent_counts_dict = agent_counts_dict

        # Aggregate episode statistics from individual envs
        self.epoch_stats = LogStatsAggregator(LogStatsLevel.EPOCH)

        # register a logger for the epoch statistics if desired
        if logging_prefix is not None:
            self.epoch_stats.register_consumer(
                get_stats_logger(logging_prefix))

        # Keep track of current actor IDs, actor dones, and env times (should be updated in step and reset methods).
        self._actor_ids = None
        self._actor_dones = None
        self._env_times = None
Example #4
    def _launch_workers(self, env: ConfigType, wrappers: CollectionOfConfigType, agent: ConfigType) \
            -> Iterable[Process]:
        """Configure the workers according to the rollout config and launch them."""
        # Split total episode count across workers
        episodes_per_process = [0] * self.n_processes
        for i in range(self.n_episodes):
            episodes_per_process[i % self.n_processes] += 1

        # Configure and launch the processes
        self.reporting_queue = Queue()
        workers = []
        for n_process_episodes in episodes_per_process:
            if n_process_episodes == 0:
                break

            p = Process(
                target=ParallelRolloutWorker.run,
                args=(env, wrappers, agent,
                      n_process_episodes, self.max_episode_steps,
                      self.record_trajectory, self.input_dir, self.reporting_queue,
                      self.maze_seeding.generate_env_instance_seed(),
                      self.maze_seeding.generate_agent_instance_seed()),
                daemon=True
            )
            p.start()
            workers.append(p)

        # Perform writer registration -- after the forks so that it is not carried over to child processes
        if self.record_event_logs:
            LogEventsWriterRegistry.register_writer(LogEventsWriterTSV(log_dir="./event_logs"))
        register_log_stats_writer(LogStatsWriterConsole())
        self.epoch_stats_aggregator = LogStatsAggregator(LogStatsLevel.EPOCH)
        self.epoch_stats_aggregator.register_consumer(get_stats_logger("rollout_stats"))

        return workers
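
The splitting loop above assigns episodes round-robin, so the per-process counts differ by at most one, and any trailing zero entries make the launch loop `break` before spawning idle workers. A self-contained sketch of just that scheme (pure Python, no Maze dependencies):

def split_episodes(n_episodes: int, n_processes: int) -> list:
    """Round-robin split of a total episode count across processes."""
    episodes_per_process = [0] * n_processes
    for i in range(n_episodes):
        episodes_per_process[i % n_processes] += 1
    return episodes_per_process

assert split_episodes(10, 4) == [3, 3, 2, 2]
assert split_episodes(2, 4) == [1, 1, 0, 0]  # zero entries -> those workers are never launched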
Example #5
    def __init__(self,
                 env_factory: Callable[[], Union[StructuredEnv, StructuredEnvSpacesMixin, LogStatsEnv]],
                 policy: TorchPolicy,
                 n_rollout_steps: int,
                 n_actors: int,
                 batch_size: int):
        self.env_factory = env_factory
        self.policy = policy
        self.n_rollout_steps = n_rollout_steps
        self.n_actors = n_actors
        self.batch_size = batch_size

        self.epoch_stats = LogStatsAggregator(LogStatsLevel.EPOCH)
        self.epoch_stats.register_consumer(get_stats_logger('train'))
Example #6
    def init_replay_buffer(replay_buffer: BaseReplayBuffer, initial_sampling_policy: Union[DictConfig, Policy],
                           initial_buffer_size: int, replay_buffer_seed: int,
                           split_rollouts_into_transitions: bool, n_rollout_steps: int,
                           env_factory: Callable[[], MazeEnv]) -> None:
        """Fill the buffer with initial_buffer_size rollouts by rolling out the initial_sampling_policy.

        :param replay_buffer: The replay buffer to use.
        :param initial_sampling_policy: The initial sampling policy used to fill the buffer to the initial fill state.
        :param initial_buffer_size: The initial size of the replay buffer filled by sampling from the initial sampling
            policy.
        :param replay_buffer_seed: A seed for initializing and sampling from the replay buffer.
        :param split_rollouts_into_transitions: Specify whether to split rollouts into individual transitions.
        :param n_rollout_steps: Number of rollout steps to record in one rollout.
        :param env_factory: Factory function for envs to run rollouts on.
        """

        # Create the log stats aggregator for collecting kpis of initializing the replay buffer
        epoch_stats = LogStatsAggregator(LogStatsLevel.EPOCH)
        replay_stats_logger = get_stats_logger('init_replay_buffer')
        epoch_stats.register_consumer(replay_stats_logger)

        dummy_env = env_factory()
        dummy_env.seed(replay_buffer_seed)
        sampling_policy: Policy = \
            Factory(Policy).instantiate(initial_sampling_policy, action_spaces_dict=dummy_env.action_spaces_dict)
        sampling_policy.seed(replay_buffer_seed)
        rollout_generator = RolloutGenerator(env=dummy_env,
                                             record_next_observations=True,
                                             record_episode_stats=True)

        print(f'******* Starting to fill the replay buffer with {initial_buffer_size} transitions *******')
        while len(replay_buffer) < initial_buffer_size:
            trajectory = rollout_generator.rollout(policy=sampling_policy, n_steps=n_rollout_steps)

            if split_rollouts_into_transitions:
                # store each step record as an individual transition
                for step_record in trajectory.step_records:
                    replay_buffer.add_transition(step_record)
            else:
                # store the whole rollout as a single buffer entry
                replay_buffer.add_rollout(trajectory)

            # collect episode statistics
            for step_record in trajectory.step_records:
                if step_record.episode_stats is not None:
                    epoch_stats.receive(step_record.episode_stats)

        # Print the kpis from initializing the replay buffer
        epoch_stats.reduce()
        # Remove the consumer again from the aggregator
        epoch_stats.remove_consumer(replay_stats_logger)
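
For reference, a hedged sketch of how init_replay_buffer might be wired up. All argument values below are placeholders (buffer, policy_config, and make_env are not Maze names); the buffer can be any BaseReplayBuffer implementation:

# Hypothetical call shape -- every right-hand-side name is a placeholder.
init_replay_buffer(
    replay_buffer=buffer,                    # any BaseReplayBuffer implementation
    initial_sampling_policy=policy_config,   # DictConfig or Policy instance
    initial_buffer_size=10_000,              # fill target, in buffer entries
    replay_buffer_seed=1234,
    split_rollouts_into_transitions=True,    # store individual step records
    n_rollout_steps=100,
    env_factory=make_env,                    # () -> MazeEnv
)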
Example #7
    def __init__(self,
                 loss: BCLoss,
                 model_selection: Optional[ModelSelectionBase],
                 data_loader: DataLoader,
                 logging_prefix: Optional[str] = "eval"):
        self.loss = loss
        self.data_loader = data_loader
        self.model_selection = model_selection

        self.env = None
        if logging_prefix:
            self.eval_stats = LogStatsAggregator(
                LogStatsLevel.EPOCH, get_stats_logger(logging_prefix))
        else:
            self.eval_stats = LogStatsAggregator(LogStatsLevel.EPOCH)
        self.eval_events = self.eval_stats.create_event_topic(ImitationEvents)
Example #8
def test_multi_group_projection():
    """ test grouping by three attributes """

    class _EventInterface(ABC):
        @define_stats_grouping("group1", "group2", "group3")
        @define_step_stats(sum, group_by="group1", output_name="g1")
        @define_step_stats(sum, group_by="group2", output_name="g2")
        @define_step_stats(sum, group_by="group3", output_name="g3")
        def event1(self, group1, group2, group3, attr1):
            pass

    agg = LogStatsAggregator(LogStatsLevel.STEP)
    agg.add_event(EventRecord(_EventInterface, _EventInterface.event1, dict(group1=1, group2=0, group3=0, attr1=1)))
    agg.add_event(EventRecord(_EventInterface, _EventInterface.event1, dict(group1=0, group2=1, group3=0, attr1=2)))
    agg.add_event(EventRecord(_EventInterface, _EventInterface.event1, dict(group1=0, group2=0, group3=1, attr1=4)))

    stats = agg.reduce()
    assert len(stats) == 6

    assert stats[(_EventInterface.event1, "g1", (0, None, None))] == 6
    assert stats[(_EventInterface.event1, "g1", (1, None, None))] == 1
    assert stats[(_EventInterface.event1, "g2", (None, 0, None))] == 5
    assert stats[(_EventInterface.event1, "g2", (None, 1, None))] == 2
    assert stats[(_EventInterface.event1, "g3", (None, None, 0))] == 3
    assert stats[(_EventInterface.event1, "g3", (None, None, 1))] == 4
Example #9
def test_event_single_attribute():
    """ test if the aggregation function receives scalars if there is only a single event attribute """

    class _EventInterface(ABC):
        @define_step_stats(sum)
        def event1(self, attr1):
            pass

    agg = LogStatsAggregator(LogStatsLevel.STEP)
    agg.add_event(EventRecord(_EventInterface, _EventInterface.event1, dict(attr1=1)))
    agg.add_event(EventRecord(_EventInterface, _EventInterface.event1, dict(attr1=2)))

    stats = agg.reduce()
    assert len(stats) == 1

    key, value = next(iter(stats.items()))
    assert value == 3
    # key is the tuple (event, output name, group)
    assert key == (_EventInterface.event1, None, None)
Example #10
def test_event_counting():
    """ test counting as a simple aggregation that operates on the attributes dict """

    class _EventInterface(ABC):
        @define_step_stats(len)
        def event1(self, attr1, attr2):
            pass

    agg = LogStatsAggregator(LogStatsLevel.STEP)
    agg.add_event(EventRecord(_EventInterface, _EventInterface.event1, dict(attr1=1, attr2=2)))
    agg.add_event(EventRecord(_EventInterface, _EventInterface.event1, dict(attr1=1, attr2=2)))

    stats = agg.reduce()
    assert len(stats) == 1

    key, value = list(stats.items())[0]
    assert value == 2
    # key is the tuple (event, output name, group)
    assert key == (_EventInterface.event1, None, None)
Example #11
class BCValidationEvaluator(Evaluator):
    """Evaluates a given policy on validation data.

    Expects that the first two items returned in the dataset tuple are the observation_dict and action_dict.

    :param data_loader: The data used for evaluation.
    :param loss: Loss function to be used.
    :param model_selection: Model selection interface that will be notified of the recorded rewards.
    """
    def __init__(self,
                 loss: BCLoss,
                 model_selection: Optional[ModelSelectionBase],
                 data_loader: DataLoader,
                 logging_prefix: Optional[str] = "eval"):
        self.loss = loss
        self.data_loader = data_loader
        self.model_selection = model_selection

        self.env = None
        if logging_prefix:
            self.eval_stats = LogStatsAggregator(
                LogStatsLevel.EPOCH, get_stats_logger(logging_prefix))
        else:
            self.eval_stats = LogStatsAggregator(LogStatsLevel.EPOCH)
        self.eval_events = self.eval_stats.create_event_topic(ImitationEvents)

    @override(Evaluator)
    def evaluate(self, policy: TorchPolicy) -> None:
        """Evaluate given policy (results are stored in stat logs) and dump the model if the reward improved.

        :param policy: Policy to evaluate.
        """
        policy.eval()
        with torch.no_grad():
            total_loss = []

            for iteration, data in enumerate(self.data_loader, 0):
                observations, actions, actor_ids = data[0], data[1], data[-1]
                actor_ids = debatch_actor_ids(actor_ids)
                # Convert only actions to torch, since observations are converted in
                # policy.compute_substep_policy_output method
                convert_to_torch(actions,
                                 device=policy.device,
                                 cast=None,
                                 in_place=True)

                total_loss.append(
                    self.loss.calculate_loss(policy=policy,
                                             observations=observations,
                                             actions=actions,
                                             events=self.eval_events,
                                             actor_ids=actor_ids).item())

            if self.model_selection:
                self.model_selection.update(-np.mean(total_loss).item())
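
A hedged usage sketch follows; bc_loss, validation_loader, and policy are placeholders for objects constructed elsewhere (a BCLoss instance, a torch DataLoader yielding (observation_dict, action_dict, ..., actor_ids) tuples, and a TorchPolicy):

# Hypothetical wiring -- placeholder objects, not a runnable recipe.
evaluator = BCValidationEvaluator(
    loss=bc_loss,                   # BCLoss instance
    model_selection=None,           # or a ModelSelectionBase to be notified of rewards
    data_loader=validation_loader,  # first two items per tuple: observation_dict, action_dict
    logging_prefix="eval",          # stats are reported under this tag
)
evaluator.evaluate(policy)          # results land in the 'eval' stat logs
increment_log_step()                # flushes epoch stats to the registered writers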
Example #12
def test_event_stats_histogram_2():
    """ test histogram loggin on an event level """

    class _EventInterface(ABC):
        @define_step_stats(histogram, input_name='attr1')
        @define_step_stats(histogram, input_name='attr2')
        def event1(self, attr1, attr2):
            pass

    agg = LogStatsAggregator(LogStatsLevel.STEP)
    agg.add_event(EventRecord(_EventInterface, _EventInterface.event1, dict(attr1=1, attr2=2)))
    agg.add_event(EventRecord(_EventInterface, _EventInterface.event1, dict(attr1=1, attr2=2)))

    stats = agg.reduce()
    assert len(stats) == 2

    value1 = stats[(_EventInterface.event1, "attr1", None)]
    value2 = stats[(_EventInterface.event1, "attr2", None)]

    assert value1 == [1, 1]
    assert value2 == [2, 2]
Example #13
def test_event_attributes():
    """ test the aggregation of individual event attributes """

    class _EventInterface(ABC):
        @define_step_stats(sum, input_name='attr1')
        @define_step_stats(sum, input_name='attr2')
        def event1(self, attr1, attr2):
            pass

    agg = LogStatsAggregator(LogStatsLevel.STEP)
    agg.add_event(EventRecord(_EventInterface, _EventInterface.event1, dict(attr1=1, attr2=3)))
    agg.add_event(EventRecord(_EventInterface, _EventInterface.event1, dict(attr1=1, attr2=3)))

    stats = agg.reduce()
    assert len(stats) == 2

    value1 = stats[(_EventInterface.event1, "attr1", None)]
    value2 = stats[(_EventInterface.event1, "attr2", None)]

    assert value1 == 2
    assert value2 == 6
Example #14
    def __init__(self, env: MazeEnv, logging_prefix: Optional[str] = None):
        """Avoid calling this constructor directly, use :method:`wrap` instead."""
        # BaseEnv is a subset of gym.Env
        super().__init__(env)

        # initialize step aggregator
        self.epoch_stats = LogStatsAggregator(LogStatsLevel.EPOCH)
        self.episode_stats = LogStatsAggregator(LogStatsLevel.EPISODE, self.epoch_stats)
        self.step_stats = LogStatsAggregator(LogStatsLevel.STEP, self.episode_stats)

        self.stats_map = {
            LogStatsLevel.EPOCH: self.epoch_stats,
            LogStatsLevel.EPISODE: self.episode_stats,
            LogStatsLevel.STEP: self.step_stats
        }

        if logging_prefix is not None:
            self.epoch_stats.register_consumer(get_stats_logger(logging_prefix))

        self.last_env_time: Optional[int] = None
        self.reward_events = EventCollection()
        self.episode_event_log: Optional[EpisodeEventLog] = None

        self.step_stats_renderer = EventStatsRenderer()

        # register a post-step callback, so stats are recorded even in case that a wrapper
        # in the middle of the stack steps the environment (as done e.g. during step-skipping)
        if hasattr(env, "context") and isinstance(env.context, EnvironmentContext):
            env.context.register_post_step(self._record_stats_if_ready)
Example #15
    def init_logging(self, trainer_config: Dict[str, Any]) -> None:
        """Initialize logging.

        This needs to be done here because on_train_result is not called in the main process, but in a worker process.

        Relies on the following:
          - There should be only one Callbacks object per worker process
          - on_train_result should always be called in the same worker (if this is not the case,
            this system should still handle it, but stats may end up inconsistent)
        """
        assert self.epoch_stats is None, "Init logging should be called only once"

        # Local epoch stats -- stats from all envs will be collected here together
        self.epoch_stats = LogStatsAggregator(LogStatsLevel.EPOCH)
        self.epoch_stats.register_consumer(get_stats_logger("train"))

        # Initialize Tensorboard and console writers
        writer = LogStatsWriterTensorboard(log_dir='.',
                                           tensorboard_render_figure=True)
        register_log_stats_writer(writer)
        register_log_stats_writer(LogStatsWriterConsole())

        summary_writer = writer.summary_writer

        # Add config to tensorboard
        yaml_config = pprint.pformat(trainer_config)
        # prepare config text for tensorboard
        yaml_config = yaml_config.replace("\n", "</br>")
        yaml_config = yaml_config.replace(" ", "&nbsp;")
        summary_writer.add_text("job_config", yaml_config)

        # Load the figures from the given files and add them to tensorboard.
        network_files = filter(lambda x: x.endswith('.figure.pkl'),
                               os.listdir('.'))
        for network_path in network_files:
            network_name = network_path.split('/')[-1].replace(
                '.figure.pkl', '')
            fig = pickle.load(open(network_path, 'rb'))
            summary_writer.add_figure(f'{network_name}', fig, close=True)
            os.remove(network_path)
Example #16
def test_multi_grouping():
    """ test grouping by three attributes """

    class _EventInterface(ABC):
        @define_stats_grouping("group1", "group2", "group3")
        @define_step_stats(sum)
        def event1(self, group1, group2, group3, attr1):
            pass

    agg = LogStatsAggregator(LogStatsLevel.STEP)
    for i in [1, 8]:
        agg.add_event(EventRecord(_EventInterface, _EventInterface.event1,
                                  dict(group1=1, group2=0, group3=0, attr1=1 * i)))
        agg.add_event(EventRecord(_EventInterface, _EventInterface.event1,
                                  dict(group1=0, group2=1, group3=0, attr1=2 * i)))
        agg.add_event(EventRecord(_EventInterface, _EventInterface.event1,
                                  dict(group1=0, group2=0, group3=1, attr1=4 * i)))

    stats = agg.reduce()
    assert len(stats) == 3

    assert stats[(_EventInterface.event1, None, (1, 0, 0))] == 9
    assert stats[(_EventInterface.event1, None, (0, 1, 0))] == 18
    assert stats[(_EventInterface.event1, None, (0, 0, 1))] == 36
Example #17
def test_event_skip_aggregation():
    """ test the once-per-step logging """

    class _EventInterface(ABC):
        @define_step_stats(None)
        def event1(self, attr1):
            pass

    agg = LogStatsAggregator(LogStatsLevel.STEP)
    agg.add_event(EventRecord(_EventInterface, _EventInterface.event1, dict(attr1=3)))

    stats = agg.reduce()
    assert len(stats) == 1

    key, value = next(iter(stats.items()))
    assert value == 3
    # key is the tuple (event, output name, group)
    assert key == (_EventInterface.event1, None, None)

    # check if multiple calls per step are correctly detected
    with pytest.raises(AssertionError):
        agg.add_event(EventRecord(_EventInterface, _EventInterface.event1, dict(attr1=3)))
        agg.add_event(EventRecord(_EventInterface, _EventInterface.event1, dict(attr1=3)))
        agg.reduce()
Example #18
def test_grouping():
    """ test the aggregation of individual event attributes """

    class _EventInterface(ABC):
        @define_stats_grouping("group")
        @define_step_stats(sum)
        def event1(self, group, attr1):
            pass

    agg = LogStatsAggregator(LogStatsLevel.STEP)
    for v in [1, 3]:
        agg.add_event(EventRecord(_EventInterface, _EventInterface.event1, dict(group=0, attr1=v)))
        agg.add_event(EventRecord(_EventInterface, _EventInterface.event1, dict(group=1, attr1=v * 2)))

    stats = agg.reduce()
    assert len(stats) == 2

    value1 = stats[(_EventInterface.event1, None, (0,))]
    value2 = stats[(_EventInterface.event1, None, (1,))]

    assert value1 == 4
    assert value2 == 8
Example #19
class LogStatsWrapper(Wrapper[MazeEnv], LogStatsEnv):
    """A statistics logging wrapper for :class:`~maze.core.env.base_env.BaseEnv`.

    :param env: The environment to wrap.
    """

    def __init__(self, env: MazeEnv, logging_prefix: Optional[str] = None):
        """Avoid calling this constructor directly, use :method:`wrap` instead."""
        # BaseEnv is a subset of gym.Env
        super().__init__(env)

        # initialize step aggregator
        self.epoch_stats = LogStatsAggregator(LogStatsLevel.EPOCH)
        self.episode_stats = LogStatsAggregator(LogStatsLevel.EPISODE, self.epoch_stats)
        self.step_stats = LogStatsAggregator(LogStatsLevel.STEP, self.episode_stats)

        self.stats_map = {
            LogStatsLevel.EPOCH: self.epoch_stats,
            LogStatsLevel.EPISODE: self.episode_stats,
            LogStatsLevel.STEP: self.step_stats
        }

        if logging_prefix is not None:
            self.epoch_stats.register_consumer(get_stats_logger(logging_prefix))

        self.last_env_time: Optional[int] = None
        self.reward_events = EventCollection()
        self.episode_event_log: Optional[EpisodeEventLog] = None

        self.step_stats_renderer = EventStatsRenderer()

        # register a post-step callback, so stats are recorded even in case that a wrapper
        # in the middle of the stack steps the environment (as done e.g. during step-skipping)
        if hasattr(env, "context") and isinstance(env.context, EnvironmentContext):
            env.context.register_post_step(self._record_stats_if_ready)

    T = TypeVar("T")

    @classmethod
    def wrap(cls, env: T, logging_prefix: Optional[str] = None) -> Union[T, LogStatsEnv]:
        """Creation method providing appropriate type hints. Preferred method to construct the wrapper
        compared to calling the class constructor directly.

        :param env: The environment to be wrapped.
        :param logging_prefix: The episode statistics is connected to the logging system with this tagging
                               prefix. If None, no logging happens.

        :return: A newly created wrapper instance.
        """
        return cls(env, logging_prefix)

    @override(BaseEnv)
    def step(self, action: Any) -> Tuple[Any, Any, bool, Dict[Any, Any]]:
        """Collect the rewards for the logging statistics
        """

        # get identifier of current substep
        substep_id, _ = self.env.actor_id() if isinstance(self.env, StructuredEnv) else (None, None)

        # take core env step
        obs, rew, done, info = self.env.step(action)

        # record the reward
        self.reward_events.append(EventRecord(BaseEnvEvents, BaseEnvEvents.reward, dict(value=rew)))

        self._record_stats_if_ready()

        return obs, rew, done, info

    def _record_stats_if_ready(self) -> None:
        """Checks if stats are ready to record based on env time (for structured envs, we wait till the end
        of the whole structured step) and if so, does the recording.
        """
        if self.last_env_time is None:
            self.last_env_time = self.env.initial_env_time

        # Recording of event logs and stats happens:
        #  - for TimeEnvs:   Only if the env time changed, so that we record once per time step
        #  - for other envs: Every step
        if isinstance(self.env, TimeEnvMixin) and self.env.get_env_time() == self.last_env_time:
            return

        step_event_log = StepEventLog(env_time=self.last_env_time, events=self.reward_events)
        self.reward_events = EventCollection()

        if isinstance(self.env, EventEnvMixin):
            step_event_log.extend(self.env.get_step_events())

        # add all recorded events to the step aggregator
        for event_record in step_event_log.events:
            self.step_stats.add_event(event_record)

        # trigger logging statistics calculation
        self.step_stats.reduce()

        # lazy init new episode event log if needed
        if not self.episode_event_log:
            episode_id = self.env.get_episode_id() if isinstance(self.env, RecordableEnvMixin) else str(uuid.uuid4())
            self.episode_event_log = EpisodeEventLog(episode_id)

        # log raw events and init new step log
        self.episode_event_log.step_event_logs.append(step_event_log)

        # update the time of last stats recording
        self.last_env_time = self.env.get_env_time() if isinstance(self.env, TimeEnvMixin) else self.last_env_time + 1

    @override(BaseEnv)
    def reset(self) -> Any:
        """Reset the environment and trigger the episode statistics calculation of the previous run.
        """
        # Generate the episode stats from the previous rollout if any
        self._calculate_kpis()
        self.episode_stats.reduce()
        self._write_episode_event_log()

        # Initialize recording for the new episode (so we can record events already during env reset)
        self.last_env_time = None
        self.reward_events = EventCollection()

        return self.env.reset()

    @override(BaseEnv)
    def close(self):
        """Close the stats rendering figure if needed."""
        self.step_stats_renderer.close()

    @override(LogStatsEnv)
    def get_stats(self, level: LogStatsLevel) -> LogStatsAggregator:
        """Implementation of the LogStatsEnv interface, return the statistics aggregator."""
        aggregator = self.stats_map[level]
        return aggregator

    @override(LogStatsEnv)
    def write_epoch_stats(self):
        """Implementation of the LogStatsEnv interface, call reduce on the episode aggregator.
        """
        if self.episode_event_log:
            self._calculate_kpis()
            self.episode_stats.reduce()
        self.epoch_stats.reduce()
        self._write_episode_event_log()
        self.episode_event_log = None

    @override(LogStatsEnv)
    def get_stats_value(self,
                        event: Callable,
                        level: LogStatsLevel,
                        name: Optional[str] = None) -> LogStatsValue:
        """Implementation of the LogStatsEnv interface, obtain the value from the cached aggregator statistics.
        """
        return self.epoch_stats.last_stats[(event, name, None)]

    @override(LogStatsEnv)
    def clear_epoch_stats(self) -> None:
        """Implementation of the LogStatsEnv interface, clear out episode statistics collected so far in this epoch."""
        self.epoch_stats.clear_inputs()

    def render_stats(self,
                     event_name: str = "BaseEnvEvents.reward",
                     metric_name: str = "value",
                     aggregation_func: Optional[Union[str, Callable]] = None,
                     group_by: str = None,
                     post_processing_func: Optional[Union[str, Callable]] = 'cumsum'):
        """Render statistics from the currently running episode.

        Rendering is based on event logs. You can select arbitrary events from those dispatched by the currently
        running environment.

        :param event_name: Name of the event the event log corresponds to
        :param metric_name: Metric to use (one of the event attributes, e.g. "n_items" -- depends on the event type)
        :param aggregation_func: Optionally, specifies how to aggregate the metric on step level, i.e. when there
                                 are multiple same events dispatched during the same step.
        :param group_by: Optionally, another of the event attributes to group by on the step level (e.g. "product_id")
        :param post_processing_func: Optionally, a function to post-process the data ("cumsum" is often used)"""
        self.step_stats_renderer.render_current_episode_stats(
            self.episode_event_log, event_name, metric_name,
            aggregation_func, group_by, post_processing_func)

    def _calculate_kpis(self):
        """Calculate KPIs and append them to both aggregated and logged events."""
        if not isinstance(self.env, EventEnvMixin) or not self.episode_event_log:
            return

        kpi_calculator = self.env.get_kpi_calculator()
        if kpi_calculator is None:
            return

        last_maze_state = self.env.get_maze_state() if isinstance(self.env, RecordableEnvMixin) else None

        kpis_dict = kpi_calculator.calculate_kpis(self.episode_event_log, last_maze_state)
        kpi_events = []
        for name, value in kpis_dict.items():
            kpi_events.append(EventRecord(BaseEnvEvents, BaseEnvEvents.kpi, dict(name=name, value=value)))

        for event_record in kpi_events:
            self.episode_stats.add_event(event_record)  # Add the events to episode aggregator
            self.episode_event_log.step_event_logs[-1].events.append(event_record)  # Log the events

    def _write_episode_event_log(self):
        """Send the episode event log to writers."""
        if self.episode_event_log:
            LogEventsWriterRegistry.record_event_logs(self.episode_event_log)

        self.episode_event_log = None

    @override(Wrapper)
    def get_observation_and_action_dicts(self, maze_state: Optional[MazeStateType],
                                         maze_action: Optional[MazeActionType],
                                         first_step_in_episode: bool) \
            -> Tuple[Optional[Dict[Union[int, str], Any]], Optional[Dict[Union[int, str], Any]]]:
        """Keep both actions and observation the same."""
        return self.env.get_observation_and_action_dicts(maze_state, maze_action, first_step_in_episode)

    @override(SimulatedEnvMixin)
    def clone_from(self, env: 'LogStatsWrapper') -> None:
        """implementation of :class:`~maze.core.env.simulated_env_mixin.SimulatedEnvMixin`."""
        raise RuntimeError("Cloning the 'LogStatsWrapper' is not supported.")

    def get_last_step_events(self, query: Union[Callable, Iterable[Callable]] = None):
        """Convenience accessor to all events recorded during the last step.

        :param query: Specify which events to return (one or more interface methods)
        :return: Recorded events from the last step (all if no query is present)
        """
        if not self.episode_event_log or len(self.episode_event_log.step_event_logs) == 0:
            return []

        last_step_log = self.episode_event_log.step_event_logs[-1]
        if query:
            return list(last_step_log.events.query_events(query))
        else:
            return list(last_step_log.events.events)
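
A hedged usage sketch of the wrapper; build_env and the step loop are placeholders, and the stats calls mirror the methods defined on the class above (the exact stats key depends on how BaseEnvEvents.reward is decorated):

# Hypothetical usage -- build_env is a placeholder env factory.
env = LogStatsWrapper.wrap(build_env(), logging_prefix="rollout")

obs = env.reset()
for _ in range(100):
    obs, rew, done, info = env.step(env.action_space.sample())
    if done:
        obs = env.reset()      # triggers episode statistics of the finished episode

env.reset()                    # close out the last episode as well
env.write_epoch_stats()        # reduce and write the epoch aggregation
value = env.get_stats_value(BaseEnvEvents.reward, LogStatsLevel.EPOCH)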
Example #20
class BCTrainer(Trainer):
    """Trainer for behavioral cloning learning.

    Runs training on top of provided trajectory data and rolls out the policy using the provided evaluator.

    In structured (multi-step) envs, all policies are trained simultaneously based on the substep actions
    and observations present in the trajectory data.
    """

    data_loader: DataLoader
    """Data loader for loading trajectory data."""

    policy: TorchPolicy
    """Structured policy to train."""

    optimizer: Optimizer
    """Optimizer to use"""

    loss: BCLoss
    """Class providing the training loss function."""

    train_stats: LogStatsAggregator = LogStatsAggregator(
        LogStatsLevel.EPOCH, get_stats_logger("train"))
    """Training statistics"""

    imitation_events: ImitationEvents = train_stats.create_event_topic(
        ImitationEvents)
    """Imitation-specific training events"""
    def __init__(self, algorithm_config: BCAlgorithmConfig,
                 data_loader: DataLoader, policy: TorchPolicy,
                 optimizer: Optimizer, loss: BCLoss):
        super().__init__(algorithm_config)

        self.data_loader = data_loader
        self.policy = policy
        self.optimizer = optimizer
        self.loss = loss

    @override(Trainer)
    def train(self,
              evaluator: Evaluator,
              n_epochs: Optional[int] = None,
              eval_every_k_iterations: Optional[int] = None) -> None:
        """
        Run training.
        :param evaluator: Evaluator to use for evaluation rollouts
        :param n_epochs: How many epochs to train for
        :param eval_every_k_iterations: Number of iterations after which to run evaluation (in addition to evaluations
        at the end of each epoch, which are run automatically). If set to None, evaluations will run on epoch end only.
        """

        if n_epochs is None:
            n_epochs = self.algorithm_config.n_epochs
        if eval_every_k_iterations is None:
            eval_every_k_iterations = self.algorithm_config.eval_every_k_iterations

        for epoch in range(n_epochs):
            print(f"\n********** Epoch {epoch + 1} started **********")
            evaluator.evaluate(self.policy)
            increment_log_step()

            for iteration, data in enumerate(self.data_loader, 0):
                observations, actions, actor_ids = data
                self._run_iteration(observations=observations,
                                    actions=actions,
                                    actor_ids=actor_ids)

                # Evaluate after each k iterations if set
                if eval_every_k_iterations is not None and \
                        iteration % eval_every_k_iterations == (eval_every_k_iterations - 1):
                    print(
                        f"\n********** Epoch {epoch + 1}: Iteration {iteration + 1} **********"
                    )
                    evaluator.evaluate(self.policy)
                    increment_log_step()

        print(f"\n********** Final evaluation **********")
        evaluator.evaluate(self.policy)
        increment_log_step()

    def load_state_dict(self, state_dict: Dict) -> None:
        """Set the model and optimizer state.
        :param state_dict: The state dict.
        """
        self.policy.load_state_dict(state_dict)

    @override(Trainer)
    def state_dict(self):
        """implementation of :class:`~maze.train.trainers.common.trainer.Trainer`
        """
        return self.policy.state_dict()

    @override(Trainer)
    def load_state(self, file_path: Union[str, BinaryIO]) -> None:
        """implementation of :class:`~maze.train.trainers.common.trainer.Trainer`
        """
        state_dict = torch.load(file_path,
                                map_location=torch.device(self.policy.device))
        self.load_state_dict(state_dict)

    def _run_iteration(self, observations: List[Union[ObservationType,
                                                      TorchObservationType]],
                       actions: List[Union[ActionType, TorchActionType]],
                       actor_ids: List[ActorID]) -> None:
        """Run a single training iterations of the behavioural cloning.

        :param observations: A list (w.r.t. the substeps/agents) of batched observations.
        :param actions: A list (w.r.t. the substeps/agents) of batched actions.
        :param actor_ids: A list (w.r.t. the substeps/agents) of the corresponding batched actor_ids.
        """
        self.policy.train()
        self.optimizer.zero_grad()

        # The actor IDs within a given batch should all be the same, so we can debatch them.
        actor_ids = debatch_actor_ids(actor_ids)

        # Convert only actions to torch, since observations are converted in policy.compute_substep_policy_output method
        actions = convert_to_torch(actions,
                                   device=self.policy.device,
                                   cast=None,
                                   in_place=True)
        total_loss = self.loss.calculate_loss(policy=self.policy,
                                              observations=observations,
                                              actions=actions,
                                              actor_ids=actor_ids,
                                              events=self.imitation_events)
        total_loss.backward()
        self.optimizer.step()

        # Report additional policy-related stats
        for actor_id in actor_ids:
            l2_norm = sum([
                param.norm()
                for param in self.policy.network_for(actor_id).parameters()
            ])
            grad_norm = compute_gradient_norm(
                self.policy.network_for(actor_id).parameters())

            self.imitation_events.policy_l2_norm(step_id=actor_id.step_key,
                                                 agent_id=actor_id.agent_id,
                                                 value=l2_norm.item())
            self.imitation_events.policy_grad_norm(step_id=actor_id.step_key,
                                                   agent_id=actor_id.agent_id,
                                                   value=grad_norm)
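
A hedged sketch of driving the trainer; algorithm_config, train_loader, validation_loader, policy, optimizer, and bc_loss are placeholders built elsewhere:

# Hypothetical wiring -- placeholders throughout, not a runnable recipe.
trainer = BCTrainer(
    algorithm_config=algorithm_config,  # BCAlgorithmConfig (n_epochs, eval_every_k_iterations)
    data_loader=train_loader,           # yields (observations, actions, actor_ids)
    policy=policy,                      # TorchPolicy to train
    optimizer=optimizer,
    loss=bc_loss,                       # BCLoss instance
)
evaluator = BCValidationEvaluator(loss=bc_loss, model_selection=None,
                                  data_loader=validation_loader)
trainer.train(evaluator)                # epochs/eval cadence fall back to algorithm_config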
Example #21
class MazeRLlibLoggingCallbacks(DefaultCallbacks):
    """Callbacks to enable Maze-style logging."""
    def __init__(self):
        super().__init__()
        self.epoch_stats = None

    def init_logging(self, trainer_config: Dict[str, Any]) -> None:
        """Initialize logging.

        This needs to be done here because on_train_result is not called in the main process, but in a worker process.

        Relies on the following:
          - There should be only one Callbacks object per worker process
          - on_train_result should always be called in the same worker (if this is not the case,
            this system should still handle it, but stats may end up inconsistent)
        """
        assert self.epoch_stats is None, "Init logging should be called only once"

        # Local epoch stats -- stats from all envs will be collected here together
        self.epoch_stats = LogStatsAggregator(LogStatsLevel.EPOCH)
        self.epoch_stats.register_consumer(get_stats_logger("train"))

        # Initialize Tensorboard and console writers
        writer = LogStatsWriterTensorboard(log_dir='.',
                                           tensorboard_render_figure=True)
        register_log_stats_writer(writer)
        register_log_stats_writer(LogStatsWriterConsole())

        summary_writer = writer.summary_writer

        # Add config to tensorboard
        yaml_config = pprint.pformat(trainer_config)
        # prepare config text for tensorboard
        yaml_config = yaml_config.replace("\n", "</br>")
        yaml_config = yaml_config.replace(" ", "&nbsp;")
        summary_writer.add_text("job_config", yaml_config)

        # Load the figures from the given files and add them to tensorboard.
        network_files = filter(lambda x: x.endswith('.figure.pkl'),
                               os.listdir('.'))
        for network_path in network_files:
            network_name = network_path.split('/')[-1].replace(
                '.figure.pkl', '')
            fig = pickle.load(open(network_path, 'rb'))
            summary_writer.add_figure(f'{network_name}', fig, close=True)
            os.remove(network_path)

    def on_train_result(self, trainer: Trainer, result: dict,
                        **kwargs) -> None:
        """Aggregates stats of all rollouts in one local aggregator and then writes them out.
        Called at the end of Trainable.train().

        :param trainer: Current model instance.
        :param result: Dict of results returned from model.train() call.
            You can mutate this object to add additional metrics.
        :param kwargs: Forward compatibility placeholder.
        """

        # Initialize the logging for this process if not done yet
        if self.epoch_stats is None:
            print("Initializing logging of train results")
            self.init_logging(trainer.config)

        # The main local aggregator should be empty
        #  - No stats should be collected here until we manually add them
        #  - Stats from the last call should be cleared out already (written out to the logs)
        assert self.epoch_stats.input == {}, "input should be empty at the beginning"

        # Get the epoch stats from the individual rollouts
        epoch_aggregators = trainer.workers.foreach_worker(
            lambda worker: worker.foreach_env(lambda env: env.get_stats(
                LogStatsLevel.EPOCH)))

        # Collect all episode stats from the epoch aggregators of individual rollout envs in the main local aggregator
        for worker_epoch_aggregator in epoch_aggregators:
            for env_epoch_aggregator in worker_epoch_aggregator:
                # Pass stats from the individual env runs into the main epoch aggregator
                for stats_key, stats_value in env_epoch_aggregator.input.items():
                    self.epoch_stats.input[stats_key].extend(stats_value)

        # clear logs at distributed workers
        def reset_episode_stats(env) -> None:
            """Empty inputs of the individual aggregators and make sure they don't have any consumers"""
            epoch_aggregator = env.get_stats(LogStatsLevel.EPOCH)
            epoch_aggregator.input = defaultdict(list)
            epoch_aggregator.consumers = []

        trainer.workers.foreach_worker(lambda worker: worker.foreach_env(
            lambda env: reset_episode_stats(env)))

        # Increment log step to trigger epoch logging
        increment_log_step()
Example #22
class DistributedActors:
    """The base class for all distributed actors.

    Distributed actors run rollouts independently. Rollouts are recorded and made available in batches
    to be used during training. When a new policy version is made available, it is distributed to all actors.

    :param env_factory: Factory function for envs to run rollouts on.
    :param policy: Structured policy to sample actions from.
    :param n_rollout_steps: Number of rollout steps to record in one rollout.
    :param n_actors: Number of distributed actors to run simultaneously.
    :param batch_size: Size of the batch the rollouts are collected in.
    """

    def __init__(self,
                 env_factory: Callable[[], Union[StructuredEnv, StructuredEnvSpacesMixin, LogStatsEnv]],
                 policy: TorchPolicy,
                 n_rollout_steps: int,
                 n_actors: int,
                 batch_size: int):
        self.env_factory = env_factory
        self.policy = policy
        self.n_rollout_steps = n_rollout_steps
        self.n_actors = n_actors
        self.batch_size = batch_size

        self.epoch_stats = LogStatsAggregator(LogStatsLevel.EPOCH)
        self.epoch_stats.register_consumer(get_stats_logger('train'))

    @abstractmethod
    def start(self) -> None:
        """Start all distributed actors"""
        raise NotImplementedError

    @abstractmethod
    def stop(self) -> None:
        """Stop all distributed actors"""
        raise NotImplementedError

    @abstractmethod
    def broadcast_updated_policy(self, state_dict: Dict) -> None:
        """Broadcast the newest version of the policy to the actors.

        :param state_dict: State of the new policy version to broadcast."""
        raise NotImplementedError

    @abstractmethod
    def collect_outputs(self, learner_device: str) -> Tuple[StructuredSpacesRecord, float, float, float]:
        """Collect `self.batch_size` actor outputs from the queue and return them batched where the first dim is
        time and the second is the batch size.

        :param learner_device: the device of the learner
        :return: A tuple of (1) batched version of ActorOutputs, (2) queue size before dequeuing,
                 (3) queue size after dequeuing, and (4) the time it took to dequeue the outputs
        """
        raise NotImplementedError

    def get_epoch_stats_aggregator(self) -> LogStatsAggregator:
        """Return the collected epoch stats aggregator"""
        return self.epoch_stats

    def get_stats_value(self,
                        event: Callable,
                        level: LogStatsLevel,
                        name: Optional[str] = None) -> LogStatsValue:
        """Obtain a single value from the epoch statistics dict.

        :param event: The event interface method of the value in question.
        :param name: The *output_name* of the statistics in case it has been specified in
                     :func:`maze.core.log_stats.event_decorators.define_epoch_stats`
        :param level: Must be set to `LogStatsLevel.EPOCH`; step or episode statistics are not propagated.
        """
        assert level == LogStatsLevel.EPOCH

        return self.epoch_stats.last_stats[(event, name, None)]
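
The (event, name, group) key convention used by get_stats_value is the same one the aggregator tests in this listing exercise. A minimal sketch of a full STEP -> EPISODE -> EPOCH chain, assuming (as the accessor above implies) that reduce() caches its result in last_stats; decorator and class names are the ones used throughout this listing:

from abc import ABC

class _Events(ABC):
    @define_epoch_stats(sum)
    @define_episode_stats(sum)
    @define_step_stats(sum)
    def score(self, value):
        pass

# chain the aggregators so each reduce() feeds the next level
epoch = LogStatsAggregator(LogStatsLevel.EPOCH)
episode = LogStatsAggregator(LogStatsLevel.EPISODE, epoch)
step = LogStatsAggregator(LogStatsLevel.STEP, episode)

step.add_event(EventRecord(_Events, _Events.score, dict(value=2.0)))
step.reduce()
episode.reduce()
epoch.reduce()

assert epoch.last_stats[(_Events.score, None, None)] == 2.0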
Example #23
class ParallelRolloutRunner(RolloutRunner):
    """Runs rollout in multiple processes in parallel.

    Both agent and environment are run in multiple instances across multiple processes. While this greatly speeds
    up the rollout, the memory consumption might be high for large environments and agents.

    Trajectory recording, event logging, as well as stats logging are supported. Trajectory logging happens
    in the child processes. Event logs and stats are shipped back to the main process so that they can be
    handled together there. This allows monitoring of progress and calculation of summary stats across
    all the processes.

    (Note that the relevant wrappers need to be present in the config for the trajectory/event/stats logging to work.
    Data are logged into the working directory managed by Hydra.)

    In case of early rollout termination using a keyboard interrupt, data for all episodes completed up to that
    point will be preserved (= written out). Graceful shutdown will be attempted, including calculation of statistics
    across the episodes completed before the rollout was terminated.

    :param n_episodes: Count of episodes to run
    :param max_episode_steps: Maximum number of steps to run in each episode (the episode ends earlier
                                if the environment returns done)
    :param n_processes: Count of processes to spread the rollout across.
    :param record_trajectory: Whether to record trajectory data
    :param record_event_logs: Whether to record event logs
    """

    def __init__(self,
                 n_episodes: int,
                 max_episode_steps: int,
                 n_processes: int,
                 record_trajectory: bool,
                 record_event_logs: bool):
        super().__init__(n_episodes, max_episode_steps, record_trajectory, record_event_logs)
        self.n_processes = n_processes
        self.epoch_stats_aggregator = None
        self.reporting_queue = None

    @override(RolloutRunner)
    def run_with(self, env: ConfigType, wrappers: CollectionOfConfigType, agent: ConfigType):
        """Run the parallel rollout in multiple worker processes."""
        workers = self._launch_workers(env, wrappers, agent)
        try:
            self._monitor_rollout(workers)
        except KeyboardInterrupt:
            self._attempt_graceful_exit(workers)

    def _launch_workers(self, env: ConfigType, wrappers: CollectionOfConfigType, agent: ConfigType) \
            -> Iterable[Process]:
        """Configure the workers according to the rollout config and launch them."""
        # Split total episode count across workers
        episodes_per_process = [0] * self.n_processes
        for i in range(self.n_episodes):
            episodes_per_process[i % self.n_processes] += 1

        # Configure and launch the processes
        self.reporting_queue = Queue()
        workers = []
        for n_process_episodes in episodes_per_process:
            if n_process_episodes == 0:
                break

            p = Process(
                target=ParallelRolloutWorker.run,
                args=(env, wrappers, agent,
                      n_process_episodes, self.max_episode_steps,
                      self.record_trajectory, self.input_dir, self.reporting_queue,
                      self.maze_seeding.generate_env_instance_seed(),
                      self.maze_seeding.generate_agent_instance_seed()),
                daemon=True
            )
            p.start()
            workers.append(p)

        # Perform writer registration -- after the forks so that it is not carried over to child processes
        if self.record_event_logs:
            LogEventsWriterRegistry.register_writer(LogEventsWriterTSV(log_dir="./event_logs"))
        register_log_stats_writer(LogStatsWriterConsole())
        self.epoch_stats_aggregator = LogStatsAggregator(LogStatsLevel.EPOCH)
        self.epoch_stats_aggregator.register_consumer(get_stats_logger("rollout_stats"))

        return workers

    def _monitor_rollout(self, workers: Iterable[Process]) -> None:
        """Collect the stats and event logs from the rollout, print progress, and join the workers when done."""

        for _ in tqdm(range(self.n_episodes), desc="Episodes done", unit=" episodes"):
            report = self.reporting_queue.get()
            if isinstance(report, ExceptionReport):
                for p in workers:
                    p.terminate()
                raise RuntimeError("A worker encountered the following error:\n"
                                   + report.traceback) from report.exception

            episode_stats, episode_event_log = report
            if episode_stats is not None:
                self.epoch_stats_aggregator.receive(episode_stats)
            if episode_event_log is not None:
                LogEventsWriterRegistry.record_event_logs(episode_event_log)

        for w in workers:
            w.join()

        if len(self.epoch_stats_aggregator.input) != 0:
            self.epoch_stats_aggregator.reduce()

    def _attempt_graceful_exit(self, workers: Iterable[Process]) -> None:
        """Print statistics collected so far and exit gracefully."""

        print("\n\nShut down requested, exiting gracefully...\n")

        for w in workers:
            w.terminate()

        if len(self.epoch_stats_aggregator.input) != 0:
            print("Stats from the completed part of rollout:\n")
            self.epoch_stats_aggregator.reduce()

        print("\nRollout done (terminated prematurely).")
Example #24
class StructuredVectorEnv(VectorEnv, StructuredEnv, StructuredEnvSpacesMixin,
                          LogStatsEnv, TimeEnvMixin, ABC):
    """Common superclass for the structured vectorised env implementations in Maze.

    :param n_envs: The number of vectorised environments.
    :param action_spaces_dict: Action spaces dict (not vectorized, as it is the same for all environments)
    :param observation_spaces_dict: Observation spaces dict (not vectorized, as it is the same for all environments)
    :param logging_prefix: If set, will report epoch statistics under this logging prefix.
    """
    def __init__(self,
                 n_envs: int,
                 action_spaces_dict: Dict[StepKeyType, gym.spaces.Space],
                 observation_spaces_dict: Dict[StepKeyType, gym.spaces.Space],
                 agent_counts_dict: Dict[StepKeyType, int],
                 logging_prefix: Optional[str] = None):
        super().__init__(n_envs)

        # Spaces
        self._action_spaces_dict = action_spaces_dict
        self._observation_spaces_dict = observation_spaces_dict
        self._agent_counts_dict = agent_counts_dict

        # Aggregate episode statistics from individual envs
        self.epoch_stats = LogStatsAggregator(LogStatsLevel.EPOCH)

        # register a logger for the epoch statistics if desired
        if logging_prefix is not None:
            self.epoch_stats.register_consumer(
                get_stats_logger(logging_prefix))

        # Keep track of current actor IDs, actor dones, and env times (should be updated in step and reset methods).
        self._actor_ids = None
        self._actor_dones = None
        self._env_times = None

    @override(StructuredEnv)
    def actor_id(self) -> ActorID:
        """Current actor ID (should be the same for all envs, as only synchronous envs are supported)."""
        assert len(set(self._actor_ids)) == 1, "only synchronous environments are supported."
        return self._actor_ids[0]

    @property
    def agent_counts_dict(self) -> Dict[StepKeyType, int]:
        """Return the agent counts of one of the vectorised envs."""
        return self._agent_counts_dict

    @override(StructuredEnv)
    def is_actor_done(self) -> np.ndarray:
        """Return the done flags of all actors in a list."""
        return self._actor_dones

    @abstractmethod
    @override(StructuredEnv)
    def get_actor_rewards(self) -> Optional[np.ndarray]:
        """Individual implementations need to override this to support structured rewards."""

    @override(TimeEnvMixin)
    def get_env_time(self) -> np.ndarray:
        """Return current env time for all vectorised environments."""
        return self._env_times

    @property
    def action_spaces_dict(self) -> Dict[Union[int, str], gym.spaces.Space]:
        """Return the action space of one of the vectorised envs."""
        return self._action_spaces_dict

    @property
    def observation_spaces_dict(
            self) -> Dict[Union[int, str], gym.spaces.Space]:
        """Return the observation space of one of the vectorised envs."""
        return self._observation_spaces_dict

    @property
    @override(StructuredEnvSpacesMixin)
    def action_space(self) -> gym.spaces.Space:
        """implementation of :class:`~maze.core.env.structured_env_spaces_mixin.StructuredEnvSpacesMixin` interface
        """
        sub_step_id, _ = self.actor_id()
        return self.action_spaces_dict[sub_step_id]

    @property
    @override(StructuredEnvSpacesMixin)
    def observation_space(self) -> gym.spaces.Space:
        """implementation of :class:`~maze.core.env.structured_env_spaces_mixin.StructuredEnvSpacesMixin` interface
        """
        sub_step_id, _ = self.actor_id()
        return self.observation_spaces_dict[sub_step_id]

    @override(LogStatsEnv)
    def get_stats(self, level: LogStatsLevel) -> LogStatsAggregator:
        """Returns the aggregator of the individual episode statistics emitted by the parallel envs.

        :param level: Must be set to `LogStatsLevel.EPOCH`; step or episode statistics are not propagated
        """
        assert level == LogStatsLevel.EPOCH
        return self.epoch_stats

    @override(LogStatsEnv)
    def write_epoch_stats(self):
        """Trigger the epoch statistics generation."""
        self.epoch_stats.reduce()

    @override(LogStatsEnv)
    def clear_epoch_stats(self) -> None:
        """Clear out episode statistics collected so far in this epoch."""
        self.epoch_stats.clear_inputs()

    @override(LogStatsEnv)
    def get_stats_value(self,
                        event: Callable,
                        level: LogStatsLevel,
                        name: Optional[str] = None) -> LogStatsValue:
        """Obtain a single value from the epoch statistics dict.

        :param event: The event interface method of the value in question.
        :param name: The *output_name* of the statistics in case it has been specified in
                     :func:`maze.core.log_stats.event_decorators.define_epoch_stats`
        :param level: Must be set to `LogStatsLevel.EPOCH`; step or episode statistics are not propagated.
        """
        assert level == LogStatsLevel.EPOCH
        return self.epoch_stats.last_stats[(event, name, None)]
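
As a quick orientation for this statistics API: a minimal usage sketch, assuming `venv` is an instance of the wrapper above, that rollouts have already been stepped through it, and that the import path and the `BaseEnvEvents.reward` event match the Maze version at hand:

# Minimal usage sketch; `venv`, the reward event and the import path are
# assumptions for illustration, not part of the class above.
from maze.core.log_stats.log_stats import LogStatsLevel

venv.write_epoch_stats()  # reduces the collected episode statistics into epoch statistics
mean_reward = venv.get_stats_value(BaseEnvEvents.reward, LogStatsLevel.EPOCH, name="mean")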
Example #25
def test_aggregation_chain_fork():
    """ test the aggregation chain with two event attributes and different aggregation operations """

    class _EventInterface(ABC):
        @define_epoch_stats(sum, input_name="attr1_sum")
        @define_epoch_stats(np.mean, input_name="attr2_mean")
        @define_episode_stats(sum, input_name="attr1_sum")
        @define_episode_stats(np.mean, input_name="attr2_mean")
        @define_step_stats(sum, input_name="attr1", output_name="attr1_sum")
        @define_step_stats(np.mean, input_name="attr1", output_name="attr1_mean")
        @define_step_stats(sum, input_name="attr2", output_name="attr2_sum")
        @define_step_stats(np.mean, input_name="attr2", output_name="attr2_mean")
        def event1(self, attr1, attr2):
            pass

    agg_episode = LogStatsAggregator(LogStatsLevel.EPOCH)
    agg_step = LogStatsAggregator(LogStatsLevel.EPISODE, agg_episode)
    agg_event = LogStatsAggregator(LogStatsLevel.STEP, agg_step)

    no_steps = 5
    no_episodes = 7
    for episode in range(no_episodes):
        for step in range(no_steps):
            agg_event.add_event(EventRecord(_EventInterface, _EventInterface.event1, dict(attr1=2.0, attr2=-2.0)))
            agg_event.add_event(EventRecord(_EventInterface, _EventInterface.event1, dict(attr1=3.0, attr2=-3.0)))

            step_stats = agg_event.reduce()
            assert len(step_stats) == 4
            value1_sum = step_stats[(_EventInterface.event1, "attr1_sum", None)]
            value1_mean = step_stats[(_EventInterface.event1, "attr1_mean", None)]
            value2_sum = step_stats[(_EventInterface.event1, "attr2_sum", None)]
            value2_mean = step_stats[(_EventInterface.event1, "attr2_mean", None)]
            assert value1_sum == 5.0
            assert value1_mean == 2.5
            assert value2_sum == -5.0
            assert value2_mean == -2.5

        episode_stats = agg_step.reduce()
        assert len(episode_stats) == 2
        value1 = episode_stats[(_EventInterface.event1, "attr1_sum", None)]
        value2 = episode_stats[(_EventInterface.event1, "attr2_mean", None)]
        assert value1 == no_steps * 5.0
        assert value2 == -2.5

    epoch_stats = agg_episode.reduce()
    assert len(epoch_stats) == 2
    value1 = epoch_stats[(_EventInterface.event1, "attr1_sum", None)]
    value2 = epoch_stats[(_EventInterface.event1, "attr2_mean", None)]
    assert value1 == no_episodes * no_steps * 5.0
    assert value2 == -2.5
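
Spelled out, the chain above fans out at the step level (sum and mean for both attributes) and then narrows again: the episode and epoch levels keep only the sum of `attr1_sum` and the mean of `attr2_mean`. A quick arithmetic check of the asserted values (the mean-of-means only stays at -2.5 here because every per-step mean is identical):

# quick arithmetic check of the values asserted in the test above
attr1_step_sum = 2.0 + 3.0           # two events per step -> 5.0
attr1_episode = 5 * attr1_step_sum   # sum over no_steps = 5 -> 25.0
attr1_epoch = 7 * attr1_episode      # sum over no_episodes = 7 -> 175.0
attr2_mean = (-2.0 + -3.0) / 2       # -> -2.5 at every aggregation level
assert (attr1_episode, attr1_epoch, attr2_mean) == (25.0, 175.0, -2.5)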
Example #26
class ESTrainer(Trainer):
    """Trainer class for OpenAI Evolution Strategies.

    :param algorithm_config: Algorithm parameters.
    :param torch_policy: Multi-step policy encapsulating the policy networks.
    :param shared_noise: The noise table, with the same content for every worker and the master.
    :param normalization_stats: Normalization statistics as calculated by the NormalizeObservationWrapper.
    """
    def __init__(
        self, algorithm_config: ESAlgorithmConfig, torch_policy: TorchPolicy,
        shared_noise: SharedNoiseTable,
        normalization_stats: Optional[Dict[str, Tuple[np.ndarray, np.ndarray]]]
    ) -> None:
        super().__init__(algorithm_config)

        # --- training setup ---
        self.model_selection: Optional[ModelSelectionBase] = None
        self.policy: Union[Policy, TorchModel] = torch_policy

        self.shared_noise = shared_noise
        self.normalization_stats = normalization_stats

        # setup the optimizer, now that the policy is available
        self.optimizer = Factory(Optimizer).instantiate(
            algorithm_config.optimizer)
        self.optimizer.setup(self.policy)

        # prepare statistics collection
        self.eval_stats = LogStatsAggregator(LogStatsLevel.EPOCH,
                                             get_stats_logger("eval"))
        self.train_stats = LogStatsAggregator(LogStatsLevel.EPOCH,
                                              get_stats_logger("train"))
        # injection of ES-specific events
        self.es_events = self.train_stats.create_event_topic(ESEvents)

    @override(Trainer)
    def train(self,
              distributed_rollouts: ESDistributedRollouts,
              n_epochs: Optional[int] = None,
              model_selection: Optional[ModelSelectionBase] = None) -> None:
        """
        Run the ES training loop.
        :param distributed_rollouts: The distribution interface for experience collection.
        :param n_epochs: Number of epochs to train.
        :param model_selection: Optional model selection class, receives model evaluation results.
        """

        n_epochs = self.algorithm_config.n_epochs if n_epochs is None else n_epochs
        self.model_selection = model_selection

        for epoch in itertools.count():
            # check if we reached the max number of epochs
            if n_epochs and epoch == n_epochs:
                break

            print('********** Iteration {} **********'.format(epoch))

            step_start_time = time.time()

            # do the actual update step (disable autograd, as we calculate the gradient from the rollout returns)
            with torch.no_grad():
                self._update(distributed_rollouts)

            step_end_time = time.time()

            # log the step duration
            self.es_events.real_time(step_end_time - step_start_time)

            # update the epoch count
            increment_log_step()

    def load_state_dict(self, state_dict: Dict) -> None:
        """Set the model and optimizer state.
        :param state_dict: The state dict.
        """
        self.policy.load_state_dict(state_dict)

    @override(Trainer)
    def state_dict(self):
        """implementation of :class:`~maze.train.trainers.common.trainer.Trainer`
        """
        return self.policy.state_dict()

    @override(Trainer)
    def load_state(self, file_path: Union[str, BinaryIO]) -> None:
        """implementation of :class:`~maze.train.trainers.common.trainer.Trainer`
        """
        state_dict = torch.load(file_path,
                                map_location=torch.device(self.policy.device))
        self.load_state_dict(state_dict)

    def _update(self, distributed_rollouts: ESDistributedRollouts):
        """Run a single ES parameter update from freshly collected distributed rollouts."""
        # Pop off results for the current task
        n_train_episodes, n_timesteps_popped = 0, 0

        # aggregate all collected training rollouts for this episode
        epoch_results = ESRolloutResult(is_eval=False)

        # obtain a generator from the distribution interface
        rollouts_generator = distributed_rollouts.generate_rollouts(
            policy=self.policy,
            max_steps=self.algorithm_config.max_steps,
            noise_stddev=self.algorithm_config.noise_stddev,
            normalization_stats=self.normalization_stats)

        # collect eval and training rollouts
        for result in rollouts_generator:
            if result.is_eval:
                # This was an eval job
                for e in result.episode_stats:
                    self.eval_stats.receive(e)
                continue

            # we received training experience from perturbed policy networks
            epoch_results.noise_indices.extend(result.noise_indices)
            epoch_results.episode_stats.extend(result.episode_stats)

            # update the training statistics
            for e in result.episode_stats:
                self.train_stats.receive(e)

                n_train_episodes += 1
                n_timesteps_popped += e[(BaseEnvEvents.reward, "count", None)]

            # stop once we have collected enough episodes and timesteps
            if (n_train_episodes >= self.algorithm_config.n_rollouts_per_update
                    and n_timesteps_popped >= self.algorithm_config.n_timesteps_per_update):
                break

        # notify the model selection of the evaluation results
        eval_stats = self.eval_stats.reduce()
        if self.model_selection and len(eval_stats):
            reward = eval_stats[(BaseEnvEvents.reward, "mean", None)]
            self.model_selection.update(reward)

        # prepare returns, reshape the positive/negative antithetic estimation as (rollouts, 2)
        returns_n2 = np.array([
            e[(BaseEnvEvents.reward, "sum", None)]
            for e in epoch_results.episode_stats
        ]).reshape(-1, 2)

        # improve robustness: weight by rank, not by reward
        proc_returns_n2 = self._compute_centered_ranks(returns_n2)

        # compute the gradient
        g = self._batched_weighted_sum(
            proc_returns_n2[:, 0] - proc_returns_n2[:, 1],
            (self.shared_noise.get(idx, self.policy.num_params)
             for idx in epoch_results.noise_indices),
            batch_size=500)

        # average over the antithetic pairs (each pair contributed one weighted noise vector)
        g /= n_train_episodes / 2.0

        # apply the weight update
        theta = get_flat_parameters(self.policy)
        update_ratio = self.optimizer.update(-g + self.algorithm_config.l2_penalty * theta.numpy())

        # statistics logging
        self.es_events.update_ratio(update_ratio)

        for policy_id in self.policy.state_dict().keys():
            self.es_events.policy_grad_norm(policy_id=policy_id,
                                            value=np.square(g).sum() ** 0.5)
            self.es_events.policy_norm(policy_id=policy_id,
                                       value=np.square(theta).sum() ** 0.5)

    @classmethod
    def _iter_groups(cls, items: Iterable,
                     group_size: int) -> Generator[Tuple, None, None]:
        """Yield the items grouped into tuples of at most `group_size` elements."""
        assert group_size >= 1
        group = []
        for x in items:
            group.append(x)
            if len(group) == group_size:
                yield tuple(group)
                del group[:]
        if group:
            yield tuple(group)

    @classmethod
    def _batched_weighted_sum(cls, weights: Iterable[float],
                              vectors: Iterable[np.ndarray],
                              batch_size: int) -> np.ndarray:
        """calculate a weighted sum of the given vectors, in steps of at most `batch_size` vectors"""
        # start with float, at the first operation numpy broadcasting takes care of the correct shape
        total: Union[np.array, float] = 0.

        for batch_weights, batch_vectors in zip(
                cls._iter_groups(weights, batch_size),
                cls._iter_groups(vectors, batch_size)):
            assert len(batch_weights) == len(batch_vectors) <= batch_size
            total += np.dot(np.asarray(batch_weights, dtype=np.float32),
                            np.asarray(batch_vectors, dtype=np.float32))

        return total

    @classmethod
    def _compute_ranks(cls, x: np.ndarray) -> np.ndarray:
        """
        Returns ranks in [0, len(x))
        Note: This is different from scipy.stats.rankdata, which returns ranks in [1, len(x)].
        """
        assert x.ndim == 1
        ranks = np.empty(len(x), dtype=int)
        ranks[x.argsort()] = np.arange(len(x))
        return ranks

    @classmethod
    def _compute_centered_ranks(cls, x: np.ndarray) -> np.ndarray:
        """Rank-transform the values and rescale the ranks to the interval [-0.5, 0.5]."""
        y = cls._compute_ranks(x.ravel()).reshape(x.shape).astype(np.float32)
        y /= (x.size - 1)
        y -= .5
        return y
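
To make the rank weighting in `_update` concrete, here is a self-contained numpy sketch that mirrors `_compute_centered_ranks` on a made-up return matrix (the toy numbers are purely illustrative):

import numpy as np

def compute_centered_ranks(x: np.ndarray) -> np.ndarray:
    """Standalone mirror of _compute_centered_ranks: map the values to their
    ranks in [0, x.size) and rescale the ranks to the interval [-0.5, 0.5]."""
    ranks = np.empty(x.size, dtype=int)
    ranks[x.ravel().argsort()] = np.arange(x.size)
    y = ranks.reshape(x.shape).astype(np.float32)
    y /= (x.size - 1)
    y -= 0.5
    return y

# toy antithetic returns, shaped (rollouts, 2) as in _update above
returns_n2 = np.array([[1.0, -3.0],
                       [10.0, 2.0]])
print(compute_centered_ranks(returns_n2))
# flattened ranks of [1, -3, 10, 2] are [1, 0, 3, 2], hence the output is
# [[1/3 - 0.5, 0/3 - 0.5], [3/3 - 0.5, 2/3 - 0.5]] = [[-0.167, -0.5], [0.5, 0.167]]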
Example #27
class BaseDistributedWorkersWithBuffer:
    """The base class for all distributed workers with buffer.

    Distributed workers run rollouts independently. Rollouts are collected by calling the collect_rollouts method and
    are then added to the buffer.

    :param env_factory: Factory function for the envs to run rollouts on.
    :param worker_policy: Structured policy to sample actions from.
    :param n_rollout_steps: Number of rollout steps to record in one rollout.
    :param n_workers: Number of distributed workers to run simultaneously.
    :param batch_size: Size of the batch the rollouts are collected in.
    :param rollouts_per_iteration: The number of rollouts to collect each time the collect_rollouts method is called.
    :param split_rollouts_into_transitions: Specify whether all computed rollouts should be split into
                                            transitions before processing them.
    :param env_instance_seeds: A list of seeds for each worker's envs.
    :param replay_buffer: The replay buffer to use.
    """

    def __init__(self,
                 env_factory: Callable[[], Union[StructuredEnv, StructuredEnvSpacesMixin, LogStatsEnv]],
                 worker_policy: TorchPolicy,
                 n_rollout_steps: int,
                 n_workers: int,
                 batch_size: int,
                 rollouts_per_iteration: int,
                 split_rollouts_into_transitions: bool,
                 env_instance_seeds: List[int],
                 replay_buffer: BaseReplayBuffer):

        self.env_factory = env_factory
        self._worker_policy = worker_policy
        self.n_rollout_steps = n_rollout_steps
        self.n_workers = n_workers
        self.batch_size = batch_size
        self.replay_buffer = replay_buffer

        self.env_instance_seeds = env_instance_seeds

        self.epoch_stats = LogStatsAggregator(LogStatsLevel.EPOCH)
        self.epoch_stats.register_consumer(get_stats_logger('train'))

        self.rollouts_per_iteration = rollouts_per_iteration
        self.split_rollouts_into_transitions = split_rollouts_into_transitions

        self._init_workers()

    @abstractmethod
    def _init_workers(self):
        """Init the agents based on the kind of distribution used"""

    @abstractmethod
    def start(self) -> None:
        """Start all distributed workers"""

    @abstractmethod
    def stop(self) -> None:
        """Stop all distributed workers"""

    def __del__(self) -> None:
        """Stop the workers when the object is deleted."""
        self.stop()

    @abstractmethod
    def broadcast_updated_policy(self, state_dict: Dict) -> None:
        """Broadcast the newest version of the policy to the workers.

        :param state_dict: State of the new policy version to broadcast.
        """

    def sample_batch(self, learner_device: str) -> StructuredSpacesRecord:
        """Sample a batch from the buffer and return it as a batched structured spaces record.

        :param learner_device: The device of the learner (cpu or cuda).
        :return: A batched structured spaces record object holding the batched rollouts.
        """
        batch = self.replay_buffer.sample_batch(n_samples=self.batch_size, learner_device=learner_device)
        if self.split_rollouts_into_transitions:
            # Stack records into one, then add an additional dimension
            stacked_records = StructuredSpacesRecord.stack_records(batch)
            return StructuredSpacesRecord.stack_records([stacked_records]).to_torch(learner_device)
        else:
            # Stack trajectories in time major, then stack into a single spaces record
            return SpacesTrajectoryRecord.stack_trajectories(batch).stack().to_torch(learner_device)

    @abstractmethod
    def collect_rollouts(self) -> Tuple[float, float, float]:
        """Collect worker outputs from the queue and add it to the buffer.

        :return: A tuple of (1) queue size before de-queueing,
                 (2) queue size after dequeueing, and (3) the time it took to dequeue the outputs
        """
        raise NotImplementedError

    def get_epoch_stats_aggregator(self) -> LogStatsAggregator:
        """Return the collected epoch stats aggregator"""
        return self.epoch_stats

    def get_stats_value(self,
                        event: Callable,
                        level: LogStatsLevel,
                        name: Optional[str] = None) -> LogStatsValue:
        """Obtain a single value from the epoch statistics dict.

        :param event: The event interface method of the value in question.
        :param name: The *output_name* of the statistics in case it has been specified in
                     :func:`maze.core.log_stats.event_decorators.define_epoch_stats`
        :param level: Must be set to `LogStatsLevel.EPOCH`; step or episode statistics are not propagated.
        """
        assert level == LogStatsLevel.EPOCH

        return self.epoch_stats.last_stats[(event, name, None)]
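
A rough learner-loop skeleton showing how the methods above are typically driven; `MyDistributedWorkers`, `policy`, `compute_loss` and `n_updates` are hypothetical placeholders, only the method names come from the base class above:

# Hypothetical learner loop around a concrete subclass of
# BaseDistributedWorkersWithBuffer; placeholder names, not Maze API.
workers = MyDistributedWorkers(...)  # concrete subclass implementing _init_workers etc.
workers.start()
try:
    for _ in range(n_updates):
        workers.collect_rollouts()                          # move worker output into the replay buffer
        batch = workers.sample_batch(learner_device="cpu")  # batched StructuredSpacesRecord
        loss = compute_loss(batch)                          # learner-specific update step
        workers.broadcast_updated_policy(policy.state_dict())
finally:
    workers.stop()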
Example #28
def test_aggregation_chain():
    """ test the aggregation chain with a single event attribute """

    class _EventInterface(ABC):
        @define_epoch_stats(sum)
        @define_episode_stats(sum)
        @define_step_stats(sum)
        def event1(self, attr1):
            pass

    agg_episode = LogStatsAggregator(LogStatsLevel.EPOCH)
    agg_step = LogStatsAggregator(LogStatsLevel.EPISODE, agg_episode)
    agg_event = LogStatsAggregator(LogStatsLevel.STEP, agg_step)

    no_steps = 5
    no_episodes = 7
    for episode in range(no_episodes):
        for step in range(no_steps):
            agg_event.add_event(EventRecord(_EventInterface, _EventInterface.event1, dict(attr1=2)))
            agg_event.add_event(EventRecord(_EventInterface, _EventInterface.event1, dict(attr1=3)))
            agg_event.reduce()

        episode_stats = agg_step.reduce()
        assert len(episode_stats) == 1
        value = episode_stats[(_EventInterface.event1, None, None)]
        assert value == no_steps * 5

    epoch_stats = agg_episode.reduce()
    assert len(epoch_stats) == 1
    value = epoch_stats[(_EventInterface.event1, None, None)]
    assert value == no_episodes * no_steps * 5
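
The asserted numbers follow directly from the chained sums; as a quick standalone check (numbers taken from the test above):

# arithmetic behind the asserts: sum at every aggregation level
step_sum = 2 + 3             # two events per step
episode_sum = 5 * step_sum   # no_steps = 5
epoch_sum = 7 * episode_sum  # no_episodes = 7
assert (step_sum, episode_sum, epoch_sum) == (5, 25, 175)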
Example #29
def test_aggregation_chain_multi_attribute():
    """ test the aggregation chain with two event attributes """

    class _EventInterface(ABC):
        @define_epoch_stats(sum, input_name="attr1")
        @define_epoch_stats(sum, input_name="attr2")
        @define_episode_stats(sum, input_name="attr1")
        @define_episode_stats(sum, input_name="attr2")
        @define_step_stats(sum, input_name="attr1")
        @define_step_stats(sum, input_name="attr2")
        def event1(self, attr1, attr2):
            pass

    agg_episode = LogStatsAggregator(LogStatsLevel.EPOCH)
    agg_step = LogStatsAggregator(LogStatsLevel.EPISODE, agg_episode)
    agg_event = LogStatsAggregator(LogStatsLevel.STEP, agg_step)

    no_steps = 5
    no_episodes = 7
    for episode in range(no_episodes):
        for step in range(no_steps):
            agg_event.add_event(EventRecord(_EventInterface, _EventInterface.event1, dict(attr1=2, attr2=-2)))
            agg_event.add_event(EventRecord(_EventInterface, _EventInterface.event1, dict(attr1=3, attr2=-3)))
            agg_event.reduce()

        episode_stats = agg_step.reduce()
        assert len(episode_stats) == 2
        value1 = episode_stats[(_EventInterface.event1, "attr1", None)]
        value2 = episode_stats[(_EventInterface.event1, "attr2", None)]
        assert value1 == no_steps * 5
        assert value2 == -no_steps * 5

    epoch_stats = agg_episode.reduce()
    assert len(epoch_stats) == 2
    value1 = epoch_stats[(_EventInterface.event1, "attr1", None)]
    value2 = epoch_stats[(_EventInterface.event1, "attr2", None)]
    assert value1 == no_episodes * no_steps * 5
    assert value2 == -no_episodes * no_steps * 5