def _load_graph(
    self, policy: TFPolicy, model_path: str, reset_global_steps: bool = False
) -> None:
    """Restore a saved TF checkpoint from model_path into the policy's session.

    :param policy: Policy whose graph/session receives the restored weights.
    :param model_path: Directory containing the checkpoint to load.
    :param reset_global_steps: If True, restart the training step counter at 0.
    :raises UnityPolicyException: If no checkpoint exists at model_path, or the
        checkpoint found there is incompatible with the current graph.
    """
    # Loading an existing model must not re-trigger the normalizer's
    # first-update initialization on the loaded policy.
    policy.first_normalization_update = False
    with policy.graph.as_default():
        logger.info(f"Loading model from {model_path}.")
        checkpoint_state = tf.train.get_checkpoint_state(model_path)
        if checkpoint_state is None:
            raise UnityPolicyException(
                "The model {} could not be loaded. Make "
                "sure you specified the right "
                "--run-id and that the previous run you are loading from had the same "
                "behavior names.".format(model_path)
            )
        if self.tf_saver:
            try:
                self.tf_saver.restore(
                    policy.sess, checkpoint_state.model_checkpoint_path
                )
            except tf.errors.NotFoundError:
                # Checkpoint exists but its variables do not match the graph.
                raise UnityPolicyException(
                    "The model {} was found but could not be loaded. Make "
                    "sure the model is from the same version of ML-Agents, has the same behavior parameters, "
                    "and is using the same trainer configuration as the current run."
                    .format(model_path)
                )
        self._check_model_version(__version__)
        if reset_global_steps:
            policy.set_step(0)
            logger.info(
                "Starting training from step 0 and saving to {}.".format(
                    self.model_path
                )
            )
        else:
            logger.info(
                f"Resuming training from step {policy.get_current_step()}."
            )
def test_step_overflow():
    """Verify the step counter does not wrap when crossing the int32 boundary."""
    behavior_spec = mb.setup_test_behavior_specs(
        use_discrete=True,
        use_visual=False,
        vector_action_space=[2],
        vector_obs_space=1,
    )
    policy = TFPolicy(
        0,
        behavior_spec,
        TrainerSettings(network_settings=NetworkSettings(normalize=True)),
        create_tf_graph=False,
    )
    policy.create_input_placeholders()
    policy.initialize()

    int32_max = 2 ** 31 - 1
    policy.set_step(int32_max)
    assert policy.get_current_step() == int32_max

    # Incrementing past 2**31 - 1 would go negative if the counter were int32.
    policy.increment_step(3)
    assert policy.get_current_step() == int32_max + 3