コード例 #1
0
 def _load_graph(self,
                 policy: TFPolicy,
                 model_path: str,
                 reset_global_steps: bool = False) -> None:
     """Load a TensorFlow checkpoint from ``model_path`` into ``policy``.

     The checkpoint is located with ``tf.train.get_checkpoint_state`` and
     restored into ``policy.sess`` through ``self.tf_saver``. If
     ``self.tf_saver`` is falsy, the restore step is skipped. The model
     version check still runs in that case.

     :param policy: Policy whose graph/session receives the restored values.
     :param model_path: Directory searched for the checkpoint to load.
     :param reset_global_steps: If True, set the policy's step to 0 after
         loading. If False, keep the restored step and resume from it.
     :raises UnityPolicyException: If no checkpoint state is found under
         ``model_path``, or if restoring it raises
         ``tf.errors.NotFoundError``. In the second case the TensorFlow
         error is chained as the cause.
     """
     # This prevents normalizer init up from executing on load
     policy.first_normalization_update = False
     with policy.graph.as_default():
         logger.info(f"Loading model from {model_path}.")
         ckpt = tf.train.get_checkpoint_state(model_path)
         if ckpt is None:
             raise UnityPolicyException(
                 "The model {} could not be loaded. Make "
                 "sure you specified the right "
                 "--run-id and that the previous run you are loading from had the same "
                 "behavior names.".format(model_path))
         if self.tf_saver:
             try:
                 self.tf_saver.restore(policy.sess,
                                       ckpt.model_checkpoint_path)
             except tf.errors.NotFoundError as err:
                 # Chain explicitly so the underlying TensorFlow error is
                 # reported as the direct cause, not incidental context.
                 raise UnityPolicyException(
                     "The model {} was found but could not be loaded. Make "
                     "sure the model is from the same version of ML-Agents, has the same behavior parameters, "
                     "and is using the same trainer configuration as the current run."
                     .format(model_path)) from err
         self._check_model_version(__version__)
         if reset_global_steps:
             # Keep the loaded weights but restart the step counter; the log
             # reports self.model_path as the destination for new saves.
             policy.set_step(0)
             logger.info(
                 "Starting training from step 0 and saving to {}.".format(
                     self.model_path))
         else:
             logger.info(
                 f"Resuming training from step {policy.get_current_step()}."
             )
コード例 #2
0
 def register(self, module: Union[TFPolicy, TFOptimizer]) -> None:
     """Attach a TensorFlow policy or optimizer to this saver.

     A ``TFPolicy`` is handed to ``_register_policy``. A ``TFOptimizer`` is
     handed to ``_register_optimizer``.

     :param module: The TFPolicy or TFOptimizer to register.
     :raises UnityPolicyException: If ``module`` is neither of those types.
     """
     # The policy check is made first, in case an object matches both types.
     if isinstance(module, TFPolicy):
         self._register_policy(module)
         return
     if isinstance(module, TFOptimizer):
         self._register_optimizer(module)
         return
     unsupported_type = type(module)
     raise UnityPolicyException(
         f"Registering Object of unsupported type {unsupported_type} to Saver ")
コード例 #3
0
ファイル: torch_model_saver.py プロジェクト: terite/HexChess
 def register(self, module: Union[TorchPolicy, TorchOptimizer]) -> None:
     """Register a Torch policy or optimizer with this model saver.

     The object's ``get_modules()`` result is merged into ``self.modules``.
     The first ``TorchPolicy`` registered also becomes ``self.policy``, and a
     ``ModelSerializer`` built from it is stored in ``self.exporter``.

     :param module: The TorchPolicy or TorchOptimizer to register.
     :raises UnityPolicyException: If ``module`` is neither of those types.
     """
     # Guard clause: a single tuple-form isinstance replaces the chained `or`.
     if not isinstance(module, (TorchPolicy, TorchOptimizer)):
         raise UnityPolicyException(
             "Registering Object of unsupported type {} to ModelSaver ".
             format(type(module)))
     self.modules.update(module.get_modules())  # type: ignore
     # Only the first policy is kept for export. Later policies still
     # contribute their modules above but do not replace self.policy.
     if self.policy is None and isinstance(module, TorchPolicy):
         self.policy = module
         self.exporter = ModelSerializer(self.policy)