Example #1
0
def test_model_management(tmpdir):
    """Exercise the ModelCheckpointManager add/track flow.

    Seeds three checkpoints into GlobalTrainingStatus, verifies that
    adding new checkpoints respects the keep-limit, and that tracking the
    final checkpoint leaves both status entries populated.
    """
    results_path = os.path.join(tmpdir, "results")
    brain_name = "Mock_brain"
    final_model_path = os.path.join(results_path, brain_name)
    # NOTE: "auxillary_file_paths" is intentionally spelled to match the
    # key used by the production ModelCheckpoint — do not correct it here.
    test_checkpoint_list = [
        {
            "steps": 1,
            "file_path": os.path.join(final_model_path, f"{brain_name}-1.nn"),
            "reward": 1.312,
            "creation_time": time.time(),
            "auxillary_file_paths": [],
        },
        {
            "steps": 2,
            "file_path": os.path.join(final_model_path, f"{brain_name}-2.nn"),
            "reward": 1.912,
            "creation_time": time.time(),
            "auxillary_file_paths": [],
        },
        {
            "steps": 3,
            "file_path": os.path.join(final_model_path, f"{brain_name}-3.nn"),
            "reward": 2.312,
            "creation_time": time.time(),
            "auxillary_file_paths": [],
        },
    ]
    GlobalTrainingStatus.set_parameter_state(brain_name,
                                             StatusType.CHECKPOINTS,
                                             test_checkpoint_list)

    # Adding a fourth checkpoint with keep_checkpoints=4 keeps all four.
    new_checkpoint_4 = ModelCheckpoint(
        4, os.path.join(final_model_path, f"{brain_name}-4.nn"), 2.678,
        time.time())
    ModelCheckpointManager.add_checkpoint(brain_name, new_checkpoint_4, 4)
    assert len(ModelCheckpointManager.get_checkpoints(brain_name)) == 4

    # Adding a fifth with the same limit evicts the oldest — still four kept.
    new_checkpoint_5 = ModelCheckpoint(
        5, os.path.join(final_model_path, f"{brain_name}-5.nn"), 3.122,
        time.time())
    ModelCheckpointManager.add_checkpoint(brain_name, new_checkpoint_5, 4)
    assert len(ModelCheckpointManager.get_checkpoints(brain_name)) == 4

    final_model_path = f"{final_model_path}.nn"
    final_model_time = time.time()
    current_step = 6
    final_model = ModelCheckpoint(current_step, final_model_path, 3.294,
                                  final_model_time)

    # The final checkpoint is tracked separately and must not count against
    # the rolling checkpoint list.
    ModelCheckpointManager.track_final_checkpoint(brain_name, final_model)
    assert len(ModelCheckpointManager.get_checkpoints(brain_name)) == 4

    check_checkpoints = GlobalTrainingStatus.saved_state[brain_name][
        StatusType.CHECKPOINTS.value]
    assert check_checkpoints is not None

    # BUGFIX: saved_state is keyed by brain name first (see the CHECKPOINTS
    # lookup above); the original indexed saved_state directly with the
    # status key, which on a defaultdict store returns a fresh empty dict
    # and makes this assertion vacuous.
    final_model = GlobalTrainingStatus.saved_state[brain_name][
        StatusType.FINAL_CHECKPOINT.value]
    assert final_model is not None
Example #2
0
def test_summary_checkpoint(mock_add_checkpoint, mock_write_summary):
    """Run several trajectories through an RL trainer and check that
    summaries and checkpoints were emitted at the configured intervals."""
    trainer = create_rl_trainer()
    fake_policy = mock.Mock()
    trainer.add_policy("TestBrain", fake_policy)
    traj_queue = AgentManagerQueue("testbrain")
    pol_queue = AgentManagerQueue("testbrain")
    trainer.subscribe_trajectory_queue(traj_queue)
    trainer.publish_policy_queue(pol_queue)

    horizon = 10
    settings = trainer.trainer_settings
    summary_freq = settings.summary_freq
    checkpoint_interval = settings.checkpoint_interval
    fake_trajectory = mb.make_fake_trajectory(
        length=horizon,
        observation_specs=create_observation_specs_with_shapes([(1, )]),
        max_step_complete=True,
        action_spec=ActionSpec.create_discrete((2, )),
    )

    # Push the same trajectory through the trainer several times.
    n_trajectories = 5
    for _ in range(n_trajectories):
        traj_queue.put(fake_trajectory)
        trainer.advance()
        # Each advance should publish an updated policy; drain it.
        pol_queue.get_nowait()

    total_steps = n_trajectories * horizon

    # write_summary should have fired once per summary interval.
    expected_summaries = [
        mock.call(step)
        for step in range(summary_freq, total_steps, summary_freq)
    ]
    mock_write_summary.assert_has_calls(expected_summaries, any_order=True)

    checkpoint_steps = range(checkpoint_interval, total_steps,
                             checkpoint_interval)
    expected_saves = [
        mock.call(trainer.brain_name, step) for step in checkpoint_steps
    ]
    trainer.model_saver.save_checkpoint.assert_has_calls(expected_saves,
                                                         any_order=True)

    export_ext = "onnx"
    expected_adds = [
        mock.call(
            trainer.brain_name,
            ModelCheckpoint(
                step,
                f"{trainer.model_saver.model_path}{os.path.sep}{trainer.brain_name}-{step}.{export_ext}",
                None,
                mock.ANY,
                [
                    f"{trainer.model_saver.model_path}{os.path.sep}{trainer.brain_name}-{step}.pt"
                ],
            ),
            trainer.trainer_settings.keep_checkpoints,
        )
        for step in checkpoint_steps
    ]
    mock_add_checkpoint.assert_has_calls(expected_adds)
Example #3
0
 def _checkpoint(self) -> ModelCheckpoint:
     """
     Checkpoints the policy associated with this trainer.

     Saves the current model through the model saver, registers the
     resulting checkpoint with the ModelCheckpointManager (trimming to the
     configured keep limit), and returns it.
     """
     # Only the first policy is saved; warn if more exist.
     if len(self.policies.keys()) > 1:
         logger.warning(
             "Trainer has multiple policies, but default behavior only saves the first."
         )
     export_ext = "onnx"
     saved_path = self.model_saver.save_checkpoint(self.brain_name, self.step)
     checkpoint = ModelCheckpoint(
         int(self.step),
         f"{saved_path}.{export_ext}",
         self._policy_mean_reward(),
         time.time(),
     )
     ModelCheckpointManager.add_checkpoint(
         self.brain_name, checkpoint, self.trainer_settings.keep_checkpoints
     )
     return checkpoint