def test_model_management(tmpdir):
    """
    Exercise ModelCheckpointManager's add/track flow against GlobalTrainingStatus.

    Seeds three checkpoints, verifies that adding beyond keep_checkpoints=4
    evicts down to 4, and that tracking the final checkpoint records it under
    the brain's entry in the saved state.
    """
    results_path = os.path.join(tmpdir, "results")
    brain_name = "Mock_brain"
    final_model_path = os.path.join(results_path, brain_name)
    # Three pre-existing checkpoints, stored as raw dicts the way
    # GlobalTrainingStatus persists them.
    test_checkpoint_list = [
        {
            "steps": 1,
            "file_path": os.path.join(final_model_path, f"{brain_name}-1.nn"),
            "reward": 1.312,
            "creation_time": time.time(),
            "auxillary_file_paths": [],
        },
        {
            "steps": 2,
            "file_path": os.path.join(final_model_path, f"{brain_name}-2.nn"),
            "reward": 1.912,
            "creation_time": time.time(),
            "auxillary_file_paths": [],
        },
        {
            "steps": 3,
            "file_path": os.path.join(final_model_path, f"{brain_name}-3.nn"),
            "reward": 2.312,
            "creation_time": time.time(),
            "auxillary_file_paths": [],
        },
    ]
    GlobalTrainingStatus.set_parameter_state(
        brain_name, StatusType.CHECKPOINTS, test_checkpoint_list
    )

    # Adding a 4th checkpoint with keep_checkpoints=4 should not evict anything.
    new_checkpoint_4 = ModelCheckpoint(
        4, os.path.join(final_model_path, f"{brain_name}-4.nn"), 2.678, time.time()
    )
    ModelCheckpointManager.add_checkpoint(brain_name, new_checkpoint_4, 4)
    assert len(ModelCheckpointManager.get_checkpoints(brain_name)) == 4

    # A 5th checkpoint with the same cap should evict the oldest, keeping 4.
    new_checkpoint_5 = ModelCheckpoint(
        5, os.path.join(final_model_path, f"{brain_name}-5.nn"), 3.122, time.time()
    )
    ModelCheckpointManager.add_checkpoint(brain_name, new_checkpoint_5, 4)
    assert len(ModelCheckpointManager.get_checkpoints(brain_name)) == 4

    # Tracking the final model must not disturb the rolling checkpoint list.
    final_model_path = f"{final_model_path}.nn"
    final_model_time = time.time()
    current_step = 6
    final_model = ModelCheckpoint(
        current_step, final_model_path, 3.294, final_model_time
    )
    ModelCheckpointManager.track_final_checkpoint(brain_name, final_model)
    assert len(ModelCheckpointManager.get_checkpoints(brain_name)) == 4

    check_checkpoints = GlobalTrainingStatus.saved_state[brain_name][
        StatusType.CHECKPOINTS.value
    ]
    assert check_checkpoints is not None

    # BUGFIX: saved_state is keyed by brain name first (as the CHECKPOINTS
    # lookup above shows); the original indexed saved_state directly with
    # the status key, which would KeyError.
    final_model = GlobalTrainingStatus.saved_state[brain_name][
        StatusType.FINAL_CHECKPOINT.value
    ]
    assert final_model is not None
def test_summary_checkpoint(mock_add_checkpoint, mock_write_summary):
    """
    Verify that advancing the trainer over several trajectories triggers
    summary writes at every summary_freq steps and checkpoint saves (plus
    checkpoint registrations) at every checkpoint_interval steps.
    """
    trainer = create_rl_trainer()
    trainer.add_policy("TestBrain", mock.Mock())

    # Wire up the trajectory input queue and the policy output queue.
    traj_queue = AgentManagerQueue("testbrain")
    pol_queue = AgentManagerQueue("testbrain")
    trainer.subscribe_trajectory_queue(traj_queue)
    trainer.publish_policy_queue(pol_queue)

    traj_len = 10
    summary_freq = trainer.trainer_settings.summary_freq
    checkpoint_interval = trainer.trainer_settings.checkpoint_interval
    trajectory = mb.make_fake_trajectory(
        length=traj_len,
        observation_specs=create_observation_specs_with_shapes([(1,)]),
        max_step_complete=True,
        action_spec=ActionSpec.create_discrete((2,)),
    )

    # Feed several trajectories through the trainer.
    n_trajectories = 5
    for _ in range(n_trajectories):
        traj_queue.put(trajectory)
        trainer.advance()

    # A policy update should have been published.
    pol_queue.get_nowait()

    # write_summary must have fired once per summary_freq steps processed.
    total_steps = n_trajectories * traj_len
    summary_calls = [
        mock.call(step) for step in range(summary_freq, total_steps, summary_freq)
    ]
    mock_write_summary.assert_has_calls(summary_calls, any_order=True)

    # save_checkpoint must have fired once per checkpoint_interval steps.
    ckpt_steps = range(checkpoint_interval, total_steps, checkpoint_interval)
    save_calls = [mock.call(trainer.brain_name, step) for step in ckpt_steps]
    trainer.model_saver.save_checkpoint.assert_has_calls(save_calls, any_order=True)

    # Each saved checkpoint must also have been registered with the manager,
    # carrying the expected onnx path and auxiliary .pt path.
    export_ext = "onnx"
    expected_registrations = []
    for step in ckpt_steps:
        expected_registrations.append(
            mock.call(
                trainer.brain_name,
                ModelCheckpoint(
                    step,
                    f"{trainer.model_saver.model_path}{os.path.sep}{trainer.brain_name}-{step}.{export_ext}",
                    None,
                    mock.ANY,
                    [
                        f"{trainer.model_saver.model_path}{os.path.sep}{trainer.brain_name}-{step}.pt"
                    ],
                ),
                trainer.trainer_settings.keep_checkpoints,
            )
        )
    mock_add_checkpoint.assert_has_calls(expected_registrations)
def _checkpoint(self) -> ModelCheckpoint:
    """
    Save a checkpoint for this trainer's policy and register it.

    Only the first policy is saved when several are attached; a warning is
    logged in that case. Returns the newly created ModelCheckpoint.
    """
    # Default behavior can persist only one policy; warn if there are more.
    if len(self.policies.keys()) > 1:
        logger.warning(
            "Trainer has multiple policies, but default behavior only saves the first."
        )

    saved_path = self.model_saver.save_checkpoint(self.brain_name, self.step)
    ext = "onnx"
    checkpoint = ModelCheckpoint(
        int(self.step),
        f"{saved_path}.{ext}",
        self._policy_mean_reward(),
        time.time(),
    )
    # Register the checkpoint, pruning old ones beyond keep_checkpoints.
    ModelCheckpointManager.add_checkpoint(
        self.brain_name, checkpoint, self.trainer_settings.keep_checkpoints
    )
    return checkpoint