예제 #1
0
def test_load_save(tmp_path):
    """Round-trip a policy through ``TFModelSaver`` checkpoints.

    The test runs in three phases:
      1. Save a checkpoint of ``policy`` at the checkpoint step under ``runid1``.
      2. Reload ``runid1`` into a fresh policy with ``load=True``. It must agree
         with the original (via ``_compare_two_policies``) and report the
         saved step.
      3. Start a new run (``runid2``) that uses ``runid1`` as ``init_path``.
         The new policy must agree with the loaded one, but its step counter
         must be 0.
    """
    path1 = os.path.join(tmp_path, "runid1")
    path2 = os.path.join(tmp_path, "runid2")
    checkpoint_step = 2000
    trainer_params = TrainerSettings()
    policy = create_policy_mock(trainer_params)
    model_saver = TFModelSaver(trainer_params, path1)
    model_saver.register(policy)
    model_saver.initialize_or_load(policy)
    policy.set_step(checkpoint_step)

    mock_brain_name = "MockBrain"
    model_saver.save_checkpoint(mock_brain_name, checkpoint_step)
    # Inspect the run directory itself. ``tmp_path`` becomes non-empty as soon
    # as ``runid1`` exists, even if nothing was written into it.
    assert os.path.isdir(path1) and os.listdir(path1)

    # Try load from this path
    model_saver = TFModelSaver(trainer_params, path1, load=True)
    policy2 = create_policy_mock(trainer_params)
    model_saver.register(policy2)
    model_saver.initialize_or_load(policy2)
    _compare_two_policies(policy, policy2)
    # Loading (as opposed to initializing) restores the step counter.
    assert policy2.get_current_step() == checkpoint_step

    # Try initialize from path 1 into a brand-new run directory.
    trainer_params.init_path = path1
    model_saver = TFModelSaver(trainer_params, path2)
    policy3 = create_policy_mock(trainer_params)
    model_saver.register(policy3)
    model_saver.initialize_or_load(policy3)

    _compare_two_policies(policy2, policy3)
    # Initializing from another run copies the policy but restarts at step 0.
    assert policy3.get_current_step() == 0
예제 #2
0
def test_register(tmp_path):
    """Check which registered objects ``TFModelSaver`` keeps as its policy.

    Registering an optimizer (a ``PPOOptimizer`` spec mock) must leave
    ``policy`` unset. Registering a policy mock must populate it.
    """
    saver = TFModelSaver(TrainerSettings(), tmp_path)

    # An optimizer must not be picked up as the policy to save.
    optimizer = mock.Mock(spec=PPOOptimizer)
    saver.register(optimizer)
    assert saver.policy is None

    # A policy, on the other hand, is stored on ``saver.policy``.
    policy_settings = TrainerSettings()
    saver.register(create_policy_mock(policy_settings))
    assert saver.policy is not None
예제 #3
0
def test_checkpoint_conversion(tmpdir, rnn, visual, discrete):
    """Saving a checkpoint should also write a ``<brain>-<step>.nn`` file.

    ``rnn``, ``visual`` and ``discrete`` select the mock policy's
    configuration. They are presumably supplied by pytest parametrization
    outside this view; verify against the decorators or fixtures.
    """
    # Start from an empty TF graph so variables from earlier tests don't leak in.
    tf.reset_default_graph()
    brain_name = "Mock_Brain"
    checkpoint_step = 100
    dummy_config = TrainerSettings()
    model_path = os.path.join(tmpdir, brain_name)
    policy = create_policy_mock(
        dummy_config, use_rnn=rnn, use_discrete=discrete, use_visual=visual
    )
    trainer_params = TrainerSettings()
    model_saver = TFModelSaver(trainer_params, model_path)
    model_saver.register(policy)
    model_saver.save_checkpoint(brain_name, checkpoint_step)
    # Build the expected path portably rather than concatenating "/".
    expected_model = os.path.join(
        model_path, "{}-{}.nn".format(brain_name, checkpoint_step)
    )
    assert os.path.isfile(expected_model)
예제 #4
0
    def test_version_compare(self):
        """``_check_model_version`` should warn only on a version mismatch.

        A bogus version string must produce exactly one record at WARNING
        level or above on the ``mlagents.trainers`` logger. Checking the
        current ``__version__`` must produce no further records.
        """
        # TemporaryDirectory removes the directory afterwards, whereas
        # mkdtemp() left a stray directory behind on every run.
        with tempfile.TemporaryDirectory() as mock_path:
            with self.assertLogs("mlagents.trainers", level="WARNING") as cm:
                trainer_params = TrainerSettings()
                policy = create_policy_mock(trainer_params)
                model_saver = TFModelSaver(trainer_params, mock_path)
                model_saver.register(policy)

                # This is not the right version for sure.
                model_saver._check_model_version("0.0.0")
                # Exactly one warning for the incorrect version.
                assert len(cm.output) == 1
                # This should be the right version.
                model_saver._check_model_version(__version__)
                # No additional warnings for the correct version.
                assert len(cm.output) == 1