def test_policy_loading_load_returns_wrong_type(tmp_path: Path):
    """Loading a persisted ensemble of ``LoadReturnsWrongTypePolicy`` must raise.

    NOTE(review): a second function with this exact name is defined later in
    this module. The later ``def`` rebinds the module attribute, so pytest
    collects only that one and this body never runs.
    """
    model_dir = str(tmp_path)

    ensemble = PolicyEnsemble([LoadReturnsWrongTypePolicy()])
    ensemble.train([], None, RegexInterpreter())
    ensemble.persist(model_dir)

    # Any exception type is accepted here; the concrete error is not pinned.
    with pytest.raises(Exception):
        PolicyEnsemble.load(model_dir)
def test_policy_loading_simple(tmp_path: Path):
    """Round-trip: a trained ensemble, persisted and reloaded, has equal policies.

    NOTE(review): a second function with this exact name is defined later in
    this module. The later ``def`` rebinds the module attribute, so pytest
    collects only that one and this body never runs.
    """
    target = str(tmp_path)

    trained = PolicyEnsemble([WorkingPolicy()])
    trained.train([], None, RegexInterpreter())
    trained.persist(target)

    restored = PolicyEnsemble.load(target)
    assert trained.policies == restored.policies
def test_policy_loading_simple(tmpdir):
    """Round-trip: a trained ensemble, persisted and reloaded, has equal policies.

    This ``def`` rebinds the name of the earlier ``test_policy_loading_simple``
    in this module, so it is the version pytest actually collects. It must
    therefore call ``train`` the same way the other tests in this module do.

    Args:
        tmpdir: pytest's legacy temporary-directory fixture; converted with
            ``str`` before being handed to ``persist``/``load``.
    """
    model_dir = str(tmpdir)

    original_policy_ensemble = PolicyEnsemble([WorkingPolicy()])
    # Every other test in this module passes an interpreter as the third
    # argument to ``train``; this call previously omitted it.
    # NOTE(review): presumably ``train`` requires it — confirm against
    # ``PolicyEnsemble.train``.
    original_policy_ensemble.train([], None, RegexInterpreter())
    original_policy_ensemble.persist(model_dir)

    loaded_policy_ensemble = PolicyEnsemble.load(model_dir)
    assert original_policy_ensemble.policies == loaded_policy_ensemble.policies
def test_policy_loading_load_returns_wrong_type(tmpdir):
    """Loading a persisted ensemble of ``LoadReturnsWrongTypePolicy`` must raise.

    This ``def`` rebinds the name of the earlier
    ``test_policy_loading_load_returns_wrong_type`` in this module, so it is
    the version pytest actually collects. It must therefore call ``train`` the
    same way the other tests in this module do.

    Args:
        tmpdir: pytest's legacy temporary-directory fixture; converted with
            ``str`` before being handed to ``persist``/``load``.
    """
    model_dir = str(tmpdir)

    original_policy_ensemble = PolicyEnsemble([LoadReturnsWrongTypePolicy()])
    # Every other test in this module passes an interpreter as the third
    # argument to ``train``; this call previously omitted it. Without it, any
    # error raised by ``train`` itself happens outside ``pytest.raises`` below.
    # NOTE(review): presumably ``train`` requires it — confirm against
    # ``PolicyEnsemble.train``.
    original_policy_ensemble.train([], None, RegexInterpreter())
    original_policy_ensemble.persist(model_dir)

    # Any exception type is accepted here; the concrete error is not pinned.
    with pytest.raises(Exception):
        PolicyEnsemble.load(model_dir)
def test_policy_loading_no_kwargs_with_no_context(
    tmp_path: Path, capsys: CaptureFixture
):
    """Loading ``PolicyWithoutLoadKwargs`` with no ``new_config`` emits a FutureWarning.

    NOTE(review): ``capsys`` is requested but never used in the body.
    """
    storage = str(tmp_path)

    ensemble = PolicyEnsemble([PolicyWithoutLoadKwargs()])
    ensemble.train([], None, RegexInterpreter())
    ensemble.persist(storage)

    with pytest.warns(FutureWarning):
        PolicyEnsemble.load(storage)
def test_policy_loading_no_kwargs_with_context(tmp_path: Path):
    """Loading with a ``new_config`` fails for a policy whose ``load`` lacks ``**kwargs``.

    The raised ``UnsupportedDialogueModelError`` must name the offending
    ``load`` method in its message.
    """
    storage = str(tmp_path)
    expected_fragment = "`PolicyWithoutLoadKwargs.load` does not accept `**kwargs`"

    ensemble = PolicyEnsemble([PolicyWithoutLoadKwargs()])
    ensemble.train([], None, RegexInterpreter())
    ensemble.persist(storage)

    with pytest.raises(UnsupportedDialogueModelError) as execinfo:
        PolicyEnsemble.load(storage, new_config={"policies": [{}]})

    assert expected_fragment in str(execinfo.value)
def test_policy_loading_load_returns_none(tmp_path: Path, caplog: LogCaptureFixture):
    """A policy whose ``load`` returns ``None`` is dropped and a warning is logged.

    Checks that the most recent captured log record has the expected message
    and that the reloaded ensemble contains no policies.
    """
    storage = str(tmp_path)

    saved = PolicyEnsemble([LoadReturnsNonePolicy()])
    saved.train([], None, RegexInterpreter())
    saved.persist(storage)

    # Capture records at WARNING and above while loading.
    with caplog.at_level(logging.WARNING):
        reloaded = PolicyEnsemble.load(storage)

    # ``pop`` removes the newest record from the captured list.
    newest_record = caplog.records.pop()
    assert newest_record.msg == (
        "Failed to load policy tests.core.test_ensemble."
        "LoadReturnsNonePolicy: load returned None"
    )
    assert len(reloaded.policies) == 0