Пример #1
0
 def _checkpoint(self) -> NNCheckpoint:
     """
     Write a checkpoint for this trainer's policy and register it.

     Only the first entry of ``self.policies`` is checkpointed; a warning is
     logged when more than one policy is present. The resulting
     ``NNCheckpoint`` is passed to ``NNCheckpointManager.add_checkpoint``
     along with ``self.trainer_settings.keep_checkpoints``.

     :return: The ``NNCheckpoint`` record created for the current step.
     """
     if len(self.policies.keys()) > 1:
         logger.warning(
             "Trainer has multiple policies, but default behavior only saves the first."
         )
     first_policy = list(self.policies.values())[0]
     base_dir = first_policy.model_path
     serialization = SerializationSettings(base_dir, self.brain_name)
     # The checkpoint path is "<model_path>/<brain_name>-<step>".
     target = os.path.join(base_dir, f"{self.brain_name}-{self.step}")
     first_policy.checkpoint(target, serialization)
     record = NNCheckpoint(
         int(self.step),
         f"{target}.nn",
         self._policy_mean_reward(),
         time.time(),
     )
     NNCheckpointManager.add_checkpoint(
         self.brain_name,
         record,
         self.trainer_settings.keep_checkpoints,
     )
     return record
def test_model_management(tmpdir):
    """
    Exercise checkpoint tracking through ``NNCheckpointManager``.

    Three checkpoint dicts are seeded into ``GlobalTrainingStatus`` for
    "Mock_brain". Two more checkpoints are then added with 4 as the third
    argument to ``add_checkpoint``, and a final checkpoint is tracked; after
    each of those steps the test expects exactly 4 tracked checkpoints.
    Finally, the stored state is read back and checked to be non-``None``.
    """
    brain_name = "Mock_brain"
    results_path = os.path.join(tmpdir, "results")
    model_base = os.path.join(results_path, brain_name)

    def tracked_count():
        # Number of checkpoints the manager currently reports for the brain.
        return len(NNCheckpointManager.get_checkpoints(brain_name))

    seeded = [
        {
            "steps": step,
            "file_path": os.path.join(model_base, f"{brain_name}-{step}.nn"),
            "reward": reward,
            "creation_time": time.time(),
        }
        for step, reward in ((1, 1.312), (2, 1.912), (3, 2.312))
    ]
    GlobalTrainingStatus.set_parameter_state(
        brain_name, StatusType.CHECKPOINTS, seeded
    )

    # Adding a 4th and then a 5th checkpoint must both leave 4 tracked.
    for step, reward in ((4, 2.678), (5, 3.122)):
        extra = NNCheckpoint(
            step,
            os.path.join(model_base, f"{brain_name}-{step}.nn"),
            reward,
            time.time(),
        )
        NNCheckpointManager.add_checkpoint(brain_name, extra, 4)
        assert tracked_count() == 4

    final_record = NNCheckpoint(6, f"{model_base}.nn", 3.294, time.time())
    NNCheckpointManager.track_final_checkpoint(brain_name, final_record)
    assert tracked_count() == 4

    stored_checkpoints = GlobalTrainingStatus.saved_state[brain_name][
        StatusType.CHECKPOINTS.value
    ]
    assert stored_checkpoints is not None

    # NOTE(review): unlike the lookup above, this one is not scoped by
    # brain_name. If the final checkpoint is stored per brain, this assertion
    # may not be checking the intended entry -- confirm against
    # GlobalTrainingStatus / track_final_checkpoint.
    stored_final = GlobalTrainingStatus.saved_state[
        StatusType.FINAL_CHECKPOINT.value
    ]
    assert stored_final is not None
Пример #3
0
 def save_model(self) -> None:
     """
     Checkpoint the trainer's policy, save it, and track the final model.

     Only the first entry of ``self.policies`` is used; a warning is logged
     when more than one policy is present. The record tracked as final is the
     checkpoint returned by ``self._checkpoint()`` with its ``file_path``
     replaced by ``"<model_path>.nn"``.
     """
     if len(self.policies.keys()) > 1:
         logger.warning(
             "Trainer has multiple policies, but default behavior only saves the first."
         )
     first_policy = list(self.policies.values())[0]
     serialization = SerializationSettings(
         first_policy.model_path, self.brain_name
     )
     latest = self._checkpoint()
     # Re-point a copy of the latest checkpoint record at the final model file.
     final_record = attr.evolve(
         latest,
         file_path=f"{first_policy.model_path}.nn",
     )
     first_policy.save(first_policy.model_path, serialization)
     NNCheckpointManager.track_final_checkpoint(self.brain_name, final_record)
Пример #4
0
    def save_model(self) -> None:
        """
        Checkpoint the trainer's policy and track the final model.

        A warning is logged when ``self.policies`` holds more than one policy.
        The checkpoint returned by ``self._checkpoint()`` is handed to
        ``self.saver.copy_final_model``, and a copy of that record with
        ``file_path`` set to ``"<saver.model_path>.nn"`` is tracked as final.
        """
        if len(self.policies.keys()) > 1:
            logger.warning(
                "Trainer has multiple policies, but default behavior only saves the first."
            )
        latest = self._checkpoint()

        # Have the saver copy the checkpointed model files to the final
        # output location.
        self.saver.copy_final_model(latest.file_path)

        final_record = attr.evolve(
            latest,
            file_path=f"{self.saver.model_path}.nn",
        )
        NNCheckpointManager.track_final_checkpoint(self.brain_name, final_record)
Пример #5
0
 def _checkpoint(self) -> NNCheckpoint:
     """
     Save a checkpoint through ``self.saver`` and register it.

     A warning is logged when ``self.policies`` holds more than one policy.
     The path returned by ``self.saver.save_checkpoint`` gets a ``.nn``
     suffix and is stored in a new ``NNCheckpoint``, which is passed to
     ``NNCheckpointManager.add_checkpoint`` along with
     ``self.trainer_settings.keep_checkpoints``.

     :return: The ``NNCheckpoint`` record created for the current step.
     """
     if len(self.policies.keys()) > 1:
         logger.warning(
             "Trainer has multiple policies, but default behavior only saves the first."
         )
     saved_to = self.saver.save_checkpoint(self.brain_name, self.step)
     record = NNCheckpoint(
         int(self.step),
         f"{saved_to}.nn",
         self._policy_mean_reward(),
         time.time(),
     )
     NNCheckpointManager.add_checkpoint(
         self.brain_name,
         record,
         self.trainer_settings.keep_checkpoints,
     )
     return record
Пример #6
0
    def save_model(self) -> None:
        """
        Checkpoint the trainer's policy and track the final model.

        Returns early, after logging a warning, when ``self.policies`` is
        empty. A warning is also logged when it holds more than one policy.
        The final record's file extension is ``nn`` when ``self.framework``
        is ``FrameworkType.TENSORFLOW`` and ``onnx`` otherwise.
        """
        n_policies = len(self.policies.keys())
        if n_policies == 0:
            logger.warning("Trainer has no policies, not saving anything.")
            return
        if n_policies > 1:
            logger.warning(
                "Trainer has multiple policies, but default behavior only saves the first."
            )

        latest = self._checkpoint()
        self.model_saver.copy_final_model(latest.file_path)

        # Pick the exported file extension from the active framework.
        if self.framework == FrameworkType.TENSORFLOW:
            export_ext = "nn"
        else:
            export_ext = "onnx"

        final_record = attr.evolve(
            latest,
            file_path=f"{self.model_saver.model_path}.{export_ext}",
        )
        NNCheckpointManager.track_final_checkpoint(self.brain_name, final_record)
Пример #7
0
 def _checkpoint(self) -> NNCheckpoint:
     """
     Save a checkpoint through ``self.model_saver`` and register it.

     A warning is logged when ``self.policies`` holds more than one policy.
     The recorded file path is the path returned by
     ``self.model_saver.save_checkpoint`` plus an extension: ``nn`` when
     ``self.framework`` is ``FrameworkType.TENSORFLOW``, ``onnx`` otherwise.

     :return: The ``NNCheckpoint`` record created for the current step.
     """
     if len(self.policies.keys()) > 1:
         logger.warning(
             "Trainer has multiple policies, but default behavior only saves the first."
         )
     saved_to = self.model_saver.save_checkpoint(self.brain_name, self.step)

     # Pick the exported file extension from the active framework.
     if self.framework == FrameworkType.TENSORFLOW:
         export_ext = "nn"
     else:
         export_ext = "onnx"

     record = NNCheckpoint(
         int(self.step),
         f"{saved_to}.{export_ext}",
         self._policy_mean_reward(),
         time.time(),
     )
     NNCheckpointManager.add_checkpoint(
         self.brain_name,
         record,
         self.trainer_settings.keep_checkpoints,
     )
     return record