def test_retrieve_none(self, fake_sqlalchemy_engine):
    """Identifiers that match nothing in the cache yield an empty DataFrame.

    The article/paragraph pair ``(-12, -1)`` is used as an identifier that is
    not expected to exist in the fake cache database.
    """
    model_name = "data_and_models/models/ner_er/model-organism"

    res = retrieve_mining_cache(
        [(-12, -1)],
        [model_name],
        fake_sqlalchemy_engine,
    )

    assert isinstance(res, pd.DataFrame)
    assert len(res) == 0
def test_retrieve_some(
    self, fake_sqlalchemy_engine, test_parameters, etypes, entity_types
):
    """Retrieve cached rows for a mix of article-level and paragraph-level identifiers.

    Only the ``"ORGANISM"`` parametrization of ``etypes`` is expected to hit
    cached rows: 3 rows spread across articles 1 and 2. Every other value
    should return an empty DataFrame.

    NOTE(review): assumes ``(1, -1)`` selects the whole of article 1 and
    ``(2, 1)`` a single paragraph of article 2 — confirm against
    ``retrieve_mining_cache``. Another ``test_retrieve_some`` (taking
    ``mining_model``) also appears in this file; if both live in the same
    class, one silently shadows the other — confirm.
    """
    identifiers = [(1, -1), (2, 1)]
    is_organism = etypes == "ORGANISM"
    # 2 rows for the article and 1 for the paragraph.
    expected_len = 3 if is_organism else 0

    res = retrieve_mining_cache(identifiers, [etypes], fake_sqlalchemy_engine)

    assert isinstance(res, pd.DataFrame)
    assert len(res) == expected_len
    expected_articles = {1, 2} if is_organism else set()
    assert set(res["article_id"].unique()) == expected_articles
def test_retrieve_all(self, fake_sqlalchemy_engine, test_parameters, entity_types):
    """Requesting every article (paragraph ``-1``) with all entity types.

    The returned row count should equal
    ``n_articles * n_sections_per_article * n_entities_per_section``, as
    configured in ``test_parameters``.
    """
    n_articles = test_parameters["n_articles"]
    # Article ids are 1-based in the fake database.
    identifiers = [(article_id, -1) for article_id in range(1, n_articles + 1)]
    expected_len = (
        n_articles
        * test_parameters["n_sections_per_article"]
        * test_parameters["n_entities_per_section"]
    )

    res = retrieve_mining_cache(identifiers, entity_types, fake_sqlalchemy_engine)

    assert isinstance(res, pd.DataFrame)
    assert len(res) == expected_len
def test_retrieve_some(self, fake_sqlalchemy_engine, test_parameters, mining_model):
    """Retrieve cached rows for one full article plus one paragraph of another.

    Only ``data_and_models/models/ner_er/model1`` is expected to have cached
    rows for these identifiers. Any other ``mining_model`` should yield an
    empty DataFrame.

    NOTE(review): another ``test_retrieve_some`` (taking ``etypes``) also
    appears in this file; if both live in the same class, one silently
    shadows the other — confirm.
    """
    cached_model = "data_and_models/models/ner_er/model1"
    identifiers = [(1, -1), (2, 1)]
    model_is_cached = mining_model == cached_model

    if model_is_cached:
        n_entities = test_parameters["n_entities_per_section"]
        # Presumably: every section of article 1, plus the single paragraph
        # requested from article 2.
        whole_article_rows = (
            1 * test_parameters["n_sections_per_article"] * n_entities
        )
        single_paragraph_rows = 1 * 1 * n_entities
        expected_len = whole_article_rows + single_paragraph_rows
    else:
        expected_len = 0

    res = retrieve_mining_cache(identifiers, [mining_model], fake_sqlalchemy_engine)

    assert isinstance(res, pd.DataFrame)
    assert len(res) == expected_len
    expected_articles = {1, 2} if model_is_cached else set()
    assert set(res["article_id"].unique()) == expected_articles