Example #1
# _load_graph is a method of ML-Agents' TF model saver class; the imports
# below reflect that context (exact module paths may differ by version).
from mlagents.tf_utils import tf
from mlagents_envs.logging_util import get_logger
from mlagents.trainers import __version__
from mlagents.trainers.policy.tf_policy import TFPolicy, UnityPolicyException

logger = get_logger(__name__)

def _load_graph(self,
                policy: TFPolicy,
                model_path: str,
                reset_global_steps: bool = False) -> None:
    # Prevent the normalizer's first-update initialization from executing on
    # load; the normalization statistics come from the checkpoint instead.
    policy.first_normalization_update = False
    with policy.graph.as_default():
        logger.info(f"Loading model from {model_path}.")
        ckpt = tf.train.get_checkpoint_state(model_path)
        if ckpt is None:
            raise UnityPolicyException(
                "The model {} could not be loaded. Make "
                "sure you specified the right "
                "--run-id and that the previous run you are loading from had the same "
                "behavior names.".format(model_path))
        if self.tf_saver:
            try:
                self.tf_saver.restore(policy.sess,
                                      ckpt.model_checkpoint_path)
            except tf.errors.NotFoundError:
                raise UnityPolicyException(
                    "The model {} was found but could not be loaded. Make "
                    "sure the model is from the same version of ML-Agents, has the same behavior parameters, "
                    "and is using the same trainer configuration as the current run."
                    .format(model_path))
        self._check_model_version(__version__)
        if reset_global_steps:
            policy.set_step(0)
            logger.info(
                "Starting training from step 0 and saving to {}.".format(
                    self.model_path))
        else:
            logger.info(
                f"Resuming training from step {policy.get_current_step()}."
            )
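A minimal usage sketch of the `reset_global_steps` flag (the caller below is hypothetical, not ML-Agents' actual API): `True` corresponds to seeding a fresh run from another run's weights, `False` to resuming the loaded run's step count.

def initialize_or_load(saver, policy, resume_path=None, init_path=None):
    # Hypothetical caller: sketches how a saver could choose between
    # restarting the step counter and keeping the checkpoint's count.
    if init_path is not None:
        # New run seeded from existing weights: restart the counter at 0.
        saver._load_graph(policy, init_path, reset_global_steps=True)
    elif resume_path is not None:
        # Resuming the same run: keep the checkpoint's step count.
        saver._load_graph(policy, resume_path, reset_global_steps=False)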
Example #2
# Imports assumed from ML-Agents' trainer tests (exact module paths may
# differ by version).
from mlagents.trainers.policy.tf_policy import TFPolicy
from mlagents.trainers.settings import NetworkSettings, TrainerSettings
from mlagents.trainers.tests import mock_brain as mb

def test_step_overflow():
    behavior_spec = mb.setup_test_behavior_specs(use_discrete=True,
                                                 use_visual=False,
                                                 vector_action_space=[2],
                                                 vector_obs_space=1)

    policy = TFPolicy(
        0,
        behavior_spec,
        TrainerSettings(network_settings=NetworkSettings(normalize=True)),
        create_tf_graph=False,
    )
    policy.create_input_placeholders()
    policy.initialize()

    # Push the step counter to the int32 maximum, then step past it: the
    # counter must keep increasing rather than wrapping to a negative value.
    policy.set_step(2**31 - 1)
    assert policy.get_current_step() == 2**31 - 1
    policy.increment_step(3)
    assert policy.get_current_step() == 2**31 + 2
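For context on the constants: 2**31 - 1 is the largest signed 32-bit integer, so the test exercises the exact boundary where a 32-bit step counter would wrap around. A standalone sketch of that failure mode using NumPy (illustration only, not part of the test suite):

import numpy as np

# A 32-bit counter wraps past 2**31 - 1 (NumPy emits an overflow warning):
step32 = np.int32(2**31 - 1) + np.int32(3)
print(step32)  # -2147483646: wrapped to a negative value

# A 64-bit counter keeps counting, which is why the assertions above
# expect 2**31 + 2 after incrementing by 3:
step64 = np.int64(2**31 - 1) + np.int64(3)
print(step64)  # 2147483650 == 2**31 + 2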