示例#1
0
 def export_model(self, name_behavior_id: str) -> None:
     """
     Export the model of the policy registered under a behavior id.

     The policy is looked up with ``get_policy``; its ``model_path`` and
     ``brain.brain_name`` form the serialization settings, and its graph and
     session are handed to ``export_policy_model``.

     :param name_behavior_id: Identifier used to look up the policy to export.
     """
     target = self.get_policy(name_behavior_id)
     export_policy_model(
         SerializationSettings(target.model_path, target.brain.brain_name),
         target.graph,
         target.sess,
     )
示例#2
0
def test_load_save(tmp_path):
    """
    Checkpoint a policy, then verify resuming from and initializing from it.

    Three policies are built:
      * the original, checkpointed at step 2000 under ``runid1``;
      * a second one created with ``load=True`` on the same path, which must
        match the original and report step 2000;
      * a third one created with ``load=False`` and ``init_path`` pointing at
        ``runid1``, which must match the second but report step 0.
    """
    first_run_dir = os.path.join(tmp_path, "runid1")
    second_run_dir = os.path.join(tmp_path, "runid2")
    settings = TrainerSettings()

    policy = create_policy_mock(settings, model_path=first_run_dir)
    policy.initialize_or_load()
    policy._set_step(2000)

    mock_brain_name = "MockBrain"
    checkpoint_path = f"{policy.model_path}/{mock_brain_name}-2000"
    nn_settings = SerializationSettings(policy.model_path, mock_brain_name)
    policy.checkpoint(checkpoint_path, nn_settings)

    # Checkpointing must have written something under the temporary directory.
    assert os.listdir(tmp_path)

    # Resume: load=True on the same path restores both weights and step count.
    resumed = create_policy_mock(
        settings, model_path=first_run_dir, load=True, seed=1
    )
    resumed.initialize_or_load()
    _compare_two_policies(policy, resumed)
    assert resumed.get_current_step() == 2000

    # Initialize: a new run pointing init_path at the first run's directory.
    settings.output_path = second_run_dir
    settings.init_path = first_run_dir
    initialized = create_policy_mock(
        settings, model_path=first_run_dir, load=False, seed=2
    )
    initialized.initialize_or_load()

    _compare_two_policies(resumed, initialized)
    # Initializing (as opposed to resuming) leaves the step counter at 0.
    assert initialized.get_current_step() == 0
示例#3
0
 def _checkpoint(self) -> NNCheckpoint:
     """
     Write a checkpoint for this trainer's first policy and register it.

     Only the first entry of ``self.policies`` is saved; a warning is logged
     when more than one policy is present. The checkpoint is written under
     the policy's ``model_path`` with a ``<brain_name>-<step>`` base name.

     :return: The ``NNCheckpoint`` record built from the current step, the
         ``.nn`` file path, the value of ``_policy_mean_reward()`` and the
         current wall-clock time.
     """
     if len(self.policies) > 1:
         logger.warning(
             "Trainer has multiple policies, but default behavior only saves the first."
         )
     first_policy = [*self.policies.values()][0]
     model_dir = first_policy.model_path
     checkpoint_path = os.path.join(model_dir, f"{self.brain_name}-{self.step}")
     first_policy.checkpoint(
         checkpoint_path, SerializationSettings(model_dir, self.brain_name)
     )
     record = NNCheckpoint(
         int(self.step),
         f"{checkpoint_path}.nn",
         self._policy_mean_reward(),
         time.time(),
     )
     # The manager is also given keep_checkpoints from the trainer settings.
     NNCheckpointManager.add_checkpoint(
         self.brain_name,
         record,
         self.trainer_settings.keep_checkpoints,
     )
     return record
def test_policy_conversion(dummy_config, tmpdir, rnn, visual, discrete):
    """
    Export a mock policy and check that a plausible ``test.nn`` file appears.

    The policy is built from ``dummy_config`` with its ``output_path`` pointed
    into ``tmpdir``, saved at step 1000 and then exported through
    ``export_policy_model``.
    """
    tf.reset_default_graph()
    dummy_config["output_path"] = os.path.join(tmpdir, "test")
    mock_policy = create_policy_mock(
        dummy_config, use_rnn=rnn, use_discrete=discrete, use_visual=visual
    )
    mock_policy.save_model(1000)
    export_policy_model(
        SerializationSettings(
            mock_policy.model_path,
            os.path.join(tmpdir, mock_policy.brain.brain_name),
        ),
        mock_policy.graph,
        mock_policy.sess,
    )

    # These checks taken from test_barracuda_converter
    exported_nn = os.path.join(tmpdir, "test.nn")
    assert os.path.isfile(exported_nn)
    assert os.path.getsize(exported_nn) > 100
示例#5
0
def test_checkpoint_writes_tf_and_nn_checkpoints(export_policy_model_mock):
    """
    Check that ``checkpoint`` saves a ``.ckpt`` and triggers one model export.

    The policy's saver and step getter are replaced with ``MagicMock`` objects
    so the test can assert on the exact arguments of the save call and of the
    ``export_policy_model_mock`` call.
    """
    mock_brain = basic_mock_brain()
    n_steps = 5
    seed = 4
    fake_policy = FakePolicy(seed, mock_brain, TrainerSettings(), "output")
    fake_policy.get_current_step = MagicMock(return_value=n_steps)
    fake_policy.saver = MagicMock()

    nn_settings = SerializationSettings("output", mock_brain.brain_name)
    checkpoint_path = f"output/{mock_brain.brain_name}-{n_steps}"
    fake_policy.checkpoint(checkpoint_path, nn_settings)

    # Exactly one saver call, targeting "<checkpoint_path>.ckpt".
    fake_policy.saver.save.assert_called_once_with(
        fake_policy.sess, f"{checkpoint_path}.ckpt"
    )
    # Exactly one export call, with the same path and settings.
    export_policy_model_mock.assert_called_once_with(
        checkpoint_path, nn_settings, fake_policy.graph, fake_policy.sess
    )
示例#6
0
def test_policy_conversion(tmpdir, rnn, visual, discrete):
    """
    Checkpoint a mock policy and check that a plausible ``.nn`` file appears.

    The policy is created with its model path under ``tmpdir`` and
    checkpointed to ``<tmpdir>/MockBrain-1``; the matching ``.nn`` file must
    exist and be larger than 100 bytes.
    """
    tf.reset_default_graph()
    mock_policy = create_policy_mock(
        TrainerSettings(),
        use_rnn=rnn,
        model_path=os.path.join(tmpdir, "test"),
        use_discrete=discrete,
        use_visual=visual,
    )
    nn_settings = SerializationSettings(mock_policy.model_path, "MockBrain")
    checkpoint_path = f"{tmpdir}/MockBrain-1"
    mock_policy.checkpoint(checkpoint_path, nn_settings)

    # These checks taken from test_barracuda_converter
    nn_file = checkpoint_path + ".nn"
    assert os.path.isfile(nn_file)
    assert os.path.getsize(nn_file) > 100
示例#7
0
 def save_model(self) -> None:
     """
     Save this trainer's first policy and track it as the final checkpoint.

     Only the first entry of ``self.policies`` is saved; a warning is logged
     when more than one policy is present. A checkpoint is taken via
     ``_checkpoint()``, the policy is saved to its ``model_path``, and a copy
     of the checkpoint record whose ``file_path`` is ``<model_path>.nn`` is
     passed to ``NNCheckpointManager.track_final_checkpoint``.
     """
     if len(self.policies) > 1:
         logger.warning(
             "Trainer has multiple policies, but default behavior only saves the first."
         )
     policy = [*self.policies.values()][0]
     # Built before the checkpoint call, matching the original ordering.
     export_settings = SerializationSettings(policy.model_path, self.brain_name)
     latest_record = self._checkpoint()
     # Same record as the checkpoint, with only file_path replaced.
     final_record = attr.evolve(latest_record, file_path=f"{policy.model_path}.nn")
     policy.save(policy.model_path, export_settings)
     NNCheckpointManager.track_final_checkpoint(self.brain_name, final_record)