コード例 #1
0
def test_globaltrainingstatus(tmpdir):
    # Round-trip a parameter through save_state / load_state and verify
    # both the on-disk JSON layout and the restored in-memory value.
    save_path = os.path.join(tmpdir, "test.json")

    GlobalTrainingStatus.set_parameter_state("Category1",
                                             StatusType.LESSON_NUM, 3)
    GlobalTrainingStatus.save_state(save_path)

    with open(save_path) as handle:
        written = json.load(handle)

    assert "Category1" in written
    category_entry = written["Category1"]
    assert StatusType.LESSON_NUM.value in category_entry
    assert category_entry[StatusType.LESSON_NUM.value] == 3
    assert "metadata" in written

    GlobalTrainingStatus.load_state(save_path)
    restored_val = GlobalTrainingStatus.get_parameter_state(
        "Category1", StatusType.LESSON_NUM)
    assert restored_val == 3

    # Lookups for an unknown category or an unknown status key return None.
    missing_category = GlobalTrainingStatus.get_parameter_state(
        "Category3", StatusType.LESSON_NUM)

    class FakeStatusType(Enum):
        NOTAREALKEY = "notarealkey"

    missing_key = GlobalTrainingStatus.get_parameter_state(
        "Category1", FakeStatusType.NOTAREALKEY)
    assert missing_category is None
    assert missing_key is None
コード例 #2
0
 def __init__(
     self,
     settings: Optional[Dict[str, EnvironmentParameterSettings]] = None,
     run_seed: int = -1,
     restore: bool = False,
 ):
     """
     EnvironmentParameterManager manages all the environment parameters of a training
     session. It determines when parameters should change and gives access to the
     current sampler of each parameter.
     :param settings: A dictionary from environment parameter to
     EnvironmentParameterSettings.
     :param run_seed: When the seed is not provided for an environment parameter,
     this seed will be used instead.
     :param restore: If true, the EnvironmentParameterManager will use the
     GlobalTrainingStatus to try and reload the lesson status of each environment
     parameter.
     """
     self._dict_settings = settings if settings is not None else {}
     for parameter_name in self._dict_settings:
         saved_lesson = GlobalTrainingStatus.get_parameter_state(
             parameter_name, StatusType.LESSON_NUM)
         # Begin at lesson 0 unless we are restoring and a saved lesson exists.
         if not restore or saved_lesson is None:
             GlobalTrainingStatus.set_parameter_state(
                 parameter_name, StatusType.LESSON_NUM, 0)
     # Per-parameter smoothed progress values, all starting at 0.0.
     self._smoothed_values: Dict[str, float] = defaultdict(float)
     for parameter_name in self._dict_settings:
         self._smoothed_values[parameter_name] = 0.0
     # Update the seeds of the samplers
     self._set_sampler_seeds(run_seed)
コード例 #3
0
def test_model_management(tmpdir):
    """
    Verify ModelCheckpointManager caps the stored checkpoint list at
    keep_checkpoints and records the final model under the brain's entry.
    """
    results_path = os.path.join(tmpdir, "results")
    brain_name = "Mock_brain"
    final_model_path = os.path.join(results_path, brain_name)
    test_checkpoint_list = [
        {
            "steps": 1,
            "file_path": os.path.join(final_model_path, f"{brain_name}-1.nn"),
            "reward": 1.312,
            "creation_time": time.time(),
            "auxillary_file_paths": [],
        },
        {
            "steps": 2,
            "file_path": os.path.join(final_model_path, f"{brain_name}-2.nn"),
            "reward": 1.912,
            "creation_time": time.time(),
            "auxillary_file_paths": [],
        },
        {
            "steps": 3,
            "file_path": os.path.join(final_model_path, f"{brain_name}-3.nn"),
            "reward": 2.312,
            "creation_time": time.time(),
            "auxillary_file_paths": [],
        },
    ]
    GlobalTrainingStatus.set_parameter_state(brain_name,
                                             StatusType.CHECKPOINTS,
                                             test_checkpoint_list)

    new_checkpoint_4 = ModelCheckpoint(
        4, os.path.join(final_model_path, f"{brain_name}-4.nn"), 2.678,
        time.time())
    ModelCheckpointManager.add_checkpoint(brain_name, new_checkpoint_4, 4)
    assert len(ModelCheckpointManager.get_checkpoints(brain_name)) == 4

    new_checkpoint_5 = ModelCheckpoint(
        5, os.path.join(final_model_path, f"{brain_name}-5.nn"), 3.122,
        time.time())
    # keep_checkpoints is 4, so adding a fifth checkpoint must evict one.
    ModelCheckpointManager.add_checkpoint(brain_name, new_checkpoint_5, 4)
    assert len(ModelCheckpointManager.get_checkpoints(brain_name)) == 4

    final_model_path = f"{final_model_path}.nn"
    final_model_time = time.time()
    current_step = 6
    final_model = ModelCheckpoint(current_step, final_model_path, 3.294,
                                  final_model_time)

    # Recording the final model must not change the retained checkpoint count.
    ModelCheckpointManager.track_final_checkpoint(brain_name, final_model)
    assert len(ModelCheckpointManager.get_checkpoints(brain_name)) == 4

    check_checkpoints = GlobalTrainingStatus.saved_state[brain_name][
        StatusType.CHECKPOINTS.value]
    assert check_checkpoints is not None

    # Bug fix: set_parameter_state stores FINAL_CHECKPOINT under the brain's
    # category (exactly like CHECKPOINTS above), so it must be read from
    # saved_state[brain_name], not from the top level of saved_state.
    final_model = GlobalTrainingStatus.saved_state[brain_name][
        StatusType.FINAL_CHECKPOINT.value]
    assert final_model is not None
コード例 #4
0
 def save_model(self) -> None:
     """
     Persist this trainer's current ELO to the global training status,
     then delegate model saving to the wrapped trainer.
     """
     GlobalTrainingStatus.set_parameter_state(
         self.brain_name, StatusType.ELO, self.current_elo)
     self.trainer.save_model()
コード例 #5
0
 def get_checkpoints(behavior_name: str) -> List[Dict[str, Any]]:
     """
     Return the checkpoint records stored for a behavior. If none exist
     yet, register (and return) a fresh empty list so later appends are
     persisted through the same object.
     :param behavior_name: Behavior whose checkpoint list is requested.
     """
     stored = GlobalTrainingStatus.get_parameter_state(
         behavior_name, StatusType.CHECKPOINTS)
     if stored:
         return stored
     stored = []
     GlobalTrainingStatus.set_parameter_state(
         behavior_name, StatusType.CHECKPOINTS, stored)
     return stored
コード例 #6
0
    def update_lessons(
        self,
        trainer_steps: Dict[str, int],
        trainer_max_steps: Dict[str, int],
        trainer_reward_buffer: Dict[str, List[float]],
    ) -> Tuple[bool, bool]:
        """
        Given progress metrics, calculates if at least one environment parameter is
        in a new lesson and if at least one environment parameter requires the env
        to reset.
        :param trainer_steps: A dictionary from behavior_name to the number of training
        steps this behavior's trainer has performed.
        :param trainer_max_steps: A dictionary from behavior_name to the maximum number
        of training steps this behavior's trainer has performed.
        :param trainer_reward_buffer: A dictionary from behavior_name to the list of
        the most recent episode returns for this behavior's trainer.
        :returns: A tuple of two booleans : (True if any lesson has changed, True if
        environment needs to reset)
        """
        must_reset = False
        updated = False
        for param_name, settings in self._dict_settings.items():
            # Current lesson index for this parameter. Assumes a LESSON_NUM
            # entry was seeded for every parameter (e.g. at construction) —
            # a None here would raise TypeError on the increment below.
            lesson_num = GlobalTrainingStatus.get_parameter_state(
                param_name, StatusType.LESSON_NUM)
            next_lesson_num = lesson_num + 1
            lesson = settings.curriculum[lesson_num]
            # Only try to advance when this lesson defines completion
            # criteria AND a next lesson actually exists in the curriculum.
            if (lesson.completion_criteria is not None
                    and len(settings.curriculum) > next_lesson_num):
                behavior_to_consider = lesson.completion_criteria.behavior
                if behavior_to_consider in trainer_steps:
                    # need_increment consumes (training-progress ratio, recent
                    # episode returns, previous smoothed value) and returns
                    # the advance decision plus the new smoothed value.
                    must_increment, new_smoothing = lesson.completion_criteria.need_increment(
                        float(trainer_steps[behavior_to_consider]) /
                        float(trainer_max_steps[behavior_to_consider]),
                        trainer_reward_buffer[behavior_to_consider],
                        self._smoothed_values[param_name],
                    )
                    # Always keep the smoothed value, advancing or not.
                    self._smoothed_values[param_name] = new_smoothing
                    if must_increment:
                        # Persist the new lesson index so it survives
                        # checkpoint/restore of the training session.
                        GlobalTrainingStatus.set_parameter_state(
                            param_name, StatusType.LESSON_NUM, next_lesson_num)
                        new_lesson_name = settings.curriculum[
                            next_lesson_num].name
                        new_lesson_value = settings.curriculum[
                            next_lesson_num].value

                        logger.info(
                            f"Parameter '{param_name}' has been updated to {new_lesson_value}."
                            + f" Now in lesson '{new_lesson_name}'")
                        updated = True
                        if lesson.completion_criteria.require_reset:
                            must_reset = True
        return updated, must_reset
コード例 #7
0
 def track_final_checkpoint(cls, behavior_name: str,
                            final_checkpoint: NNCheckpoint) -> None:
     """
     Store the information about the final model (or the intermediate
     model if training is interrupted) in the global training status,
     keyed under the behavior's entry.
     :param behavior_name: Behavior name of the model.
     :param final_checkpoint: Checkpoint information for the final model.
     """
     # Serialize the attrs checkpoint object into a plain dict for storage.
     GlobalTrainingStatus.set_parameter_state(
         behavior_name, StatusType.FINAL_CHECKPOINT,
         attr.asdict(final_checkpoint))
コード例 #8
0
 def add_checkpoint(cls, behavior_name: str, new_checkpoint: NNCheckpoint,
                    keep_checkpoints: int) -> None:
     """
     Append a new checkpoint record for a behavior, trim older entries so
     the retained count respects the user's limit, and persist the result.
     :param behavior_name: Behavior name for the checkpoint.
     :param new_checkpoint: The new checkpoint to be recorded.
     :param keep_checkpoints: Number of checkpoints to record (user-defined).
     """
     history = cls.get_checkpoints(behavior_name)
     # attrs object -> plain dict, matching the stored record format.
     history.append(attr.asdict(new_checkpoint))
     cls._cleanup_extra_checkpoints(history, keep_checkpoints)
     GlobalTrainingStatus.set_parameter_state(
         behavior_name, StatusType.CHECKPOINTS, history)
コード例 #9
0
    def advance(self, env: EnvManager) -> int:
        """
        Step the environment, report and persist the current curriculum
        lesson for each brain, and advance all non-threaded trainers.
        :param env: The EnvManager to step.
        :return: The number of environment steps taken.
        """
        with hierarchical_timer("env_step"):
            num_steps = env.advance()

        # Publish each brain's current lesson as a stat and persist it in
        # the global training status so it can be restored later.
        if self.meta_curriculum:
            curricula = self.meta_curriculum.brains_to_curricula
            for brain_name, curriculum in curricula.items():
                if brain_name not in self.trainers:
                    continue
                self.trainers[brain_name].stats_reporter.set_stat(
                    "Environment/Lesson", curriculum.lesson_num)
                GlobalTrainingStatus.set_parameter_state(
                    brain_name, StatusType.LESSON_NUM, curriculum.lesson_num)

        # Threaded trainers advance themselves; step only the rest here.
        for trainer in self.trainers.values():
            if not trainer.threaded:
                with hierarchical_timer("trainer_advance"):
                    trainer.advance()

        return num_steps