def test_keys():
    """keys() must expose exactly the keys that were written to the database."""
    contents = {
        "A": ["B", "C"],
        "B": ["A"],
        "C": ["A"],
    }
    lookup = Sqlite3LookupTable(make_sqlite3_db("test_keys", contents))

    stored_keys = set(lookup.keys())

    assert stored_keys == set(contents)
def test_sqlite3_lookup_getitem():
    """Indexing the table by key returns the value stored under that key."""
    contents = {"A": [1, 2, 3], "B": [4, 5, 6]}
    path = make_sqlite3_db("test_sqlite3_lookup_getitem", contents)
    lookup = Sqlite3LookupTable(path)

    for key in ("A", "B"):
        assert lookup[key] == contents[key]
def test_len():
    """len(table) equals the number of entries written to the database."""
    contents = {
        "A": ["B", "C"],
        "B": ["A"],
        "C": ["A"],
    }
    lookup = Sqlite3LookupTable(make_sqlite3_db("test_len", contents))

    n_rows = len(lookup)

    assert n_rows == len(contents)
def test_sqlite3_lookup_contains():
    """`in` reports stored keys as present and unknown keys as absent."""
    contents = {"A": [1, 2, 3], "B": [4, 5, 6]}
    path = make_sqlite3_db("test_sqlite3_lookup_contains", contents)
    lookup = Sqlite3LookupTable(path)

    for stored_key in ("A", "B"):
        assert stored_key in lookup
    assert "C" not in lookup
def test_iter():
    """Iterating the table yields every (key, value) pair that was stored."""
    contents = {
        "A": ["B", "C"],
        "B": ["A"],
        "C": ["A"],
    }
    lookup = Sqlite3LookupTable(make_sqlite3_db("test_iter", contents))

    # Iterate explicitly so that the table's own iteration protocol is what
    # gets exercised here.
    collected = {}
    for key, value in lookup:
        collected[key] = value

    assert collected == contents
def test_len():
    """A dataset wrapping a lookup table reports one item per stored entry."""
    contents = {"A": "1", "B": "2", "C": "1", "D": "2"}
    lookup = Sqlite3LookupTable(make_sqlite3_db("test_len", contents))

    wrapped = sqlite3_dataset.Sqlite3Dataset(lookup)

    assert len(wrapped) == len(contents)
def test_sqlite3_is_preloaded():
    """is_preloaded() is False on a fresh table and True after preload()."""
    expected = {
        "A": [1, 2, 3],
        "B": [4, 5, 6],
    }
    # Use a database name unique to this test. The previous name,
    # "test_sqlite3_lookup_preload", is the one used by the preload test,
    # which deletes its database file; sharing it risks the two tests
    # interfering with each other.
    # NOTE(review): assumes make_sqlite3_db derives the file location from
    # this name -- confirm against the helper.
    db_path = make_sqlite3_db("test_sqlite3_is_preloaded", expected)
    table = Sqlite3LookupTable(db_path)

    assert not table.is_preloaded()

    # Should load the table contents to memory
    table.preload()

    assert table.is_preloaded()
def test_custom_value_column_name():
    """Lookups work when the value column uses a non-default name."""
    contents = {"A": [1, 2, 3], "B": [4, 5, 6]}
    path = make_sqlite3_db(
        "test_custom_value_column_name",
        contents,
        value_column_name="custom",
    )
    lookup = Sqlite3LookupTable(path, value_column_name="custom")

    for key in ("A", "B"):
        assert lookup[key] == contents[key]
    assert "C" not in lookup
def test_iter_where():
    """iterate(where=...) yields only the entries matching the condition."""
    stored = {
        "AA": ["B", "C"],
        "BBBB": ["A"],
        "CC": ["A"],
    }
    lookup = Sqlite3LookupTable(make_sqlite3_db("test_iter_where", stored))

    matched = {}
    for key, value in lookup.iterate(where="length(key) = 2"):
        matched[key] = value

    # Only the two-character keys should survive the filter.
    assert matched == {"AA": ["B", "C"], "CC": ["A"]}
def test_sqlite3_lookup_preload():
    """After preload(), lookups still work once the database file is gone."""
    contents = {"A": [1, 2, 3], "B": [4, 5, 6]}
    path = make_sqlite3_db("test_sqlite3_lookup_preload", contents)
    lookup = Sqlite3LookupTable(path)

    # Should load the table contents to memory
    lookup.preload()

    # Delete the on-disk database so the assertions below can only pass if
    # the content was actually loaded.
    path.unlink()

    for key in ("A", "B"):
        assert lookup[key] == contents[key]
    assert "C" not in lookup
def test_getitem():
    """Indexing the dataset by position returns every stored (key, value)."""
    contents = {"A": "1", "B": "2", "C": "1", "D": "2"}
    lookup = Sqlite3LookupTable(make_sqlite3_db("test_getitem", contents))
    wrapped = sqlite3_dataset.Sqlite3Dataset(lookup)

    # Positional indexing is deliberate: __getitem__ is what is under test.
    pairs = (wrapped[position] for position in range(len(wrapped)))
    collected = {key: value for key, value in pairs}

    assert contents == collected
def test_sqlite3_lookup_pickle():
    """A table that is pickled and unpickled still serves its lookups."""
    expected = {
        "A": [1, 2, 3],
        "B": [4, 5, 6],
    }
    db_path = make_sqlite3_db("test_sqlite3_lookup_pickle", expected)
    table = Sqlite3LookupTable(db_path)

    # Round-trip through an in-memory byte string rather than a hard-coded
    # /tmp file: no platform-specific path and no leftover artifact.
    payload = pickle.dumps(table)
    # Drop the original object so the assertions run against the restored
    # copy only.
    del table
    table = pickle.loads(payload)

    assert table["A"] == expected["A"]
    assert table["B"] == expected["B"]
    assert "C" not in table
def test_subset():
    """A dataset built with a key filter exposes only the accepted entries."""
    data = {
        "1": "A",
        "2": "B",
        "3": "C",
        "4": "D",
    }
    # Use a database name unique to this test; it previously reused
    # "test_getitem", the name used by the getitem test.
    db_path = make_sqlite3_db("test_subset", data)
    table = Sqlite3LookupTable(db_path)

    def filter_fn(k):
        """Accept keys whose integer value is at most 2."""
        return int(k) <= 2

    dataset = sqlite3_dataset.Sqlite3Dataset(table, filter_fn)

    # Positional indexing is deliberate: it exercises the dataset's
    # __len__ and __getitem__ on the filtered view.
    actual = {}
    for idx in range(len(dataset)):
        k, v = dataset[idx]
        actual[k] = v

    expected = {
        "1": "A",
        "2": "B",
    }
    assert expected == actual
def test_backward_compatable_fallback():
    """Lookups succeed even when the table is opened with wrong custom names."""
    contents = {"A": [1, 2, 3], "B": [4, 5, 6]}
    # The database itself is created with the default names.
    path = make_sqlite3_db("test_backward_compatable_fallback", contents)

    # table set with custom (incorrect) names
    # expected behavior, fall back to defaults
    custom_names = {
        "table_name": "custom_table",
        "key_column_name": "custom_key",
        "value_column_name": "custom_value",
    }
    lookup = Sqlite3LookupTable(path, **custom_names)

    for key in ("A", "B"):
        assert lookup[key] == contents[key]
    assert "C" not in lookup
def __init__(
    self,
    embedding_dir: Path,
    entity_db: Path,
    disable_cache: bool = False,
):
    """Index the embedding files in a directory and open the entity table.

    Args:
        embedding_dir: Directory searched (non-recursively) for files
            matching ``embeddings_*.h5``.
        entity_db: Path of the file handed to ``Sqlite3LookupTable``.
        disable_cache: Forwarded to ``Sqlite3LookupTable``; its negation
            is stored as ``self._use_cache``.

    Raises:
        AssertionError: If ``embedding_dir`` is not a directory,
            ``entity_db`` is not a file, or no embedding file is found.
    """
    embedding_dir = Path(embedding_dir)
    entity_db = Path(entity_db)
    assert embedding_dir.is_dir(), "Failed to find embedding_dir"
    assert entity_db.is_file(), "Failed to find entities"

    self.entities = Sqlite3LookupTable(entity_db, disable_cache=disable_cache)

    # Map the key parsed from each embedding file's path to that path.
    self._type_part2path = {
        parse_embedding_path(embedding_path): embedding_path
        for embedding_path in embedding_dir.glob("embeddings_*.h5")
    }
    # Test for emptiness directly: any() over a dict checks whether some
    # *key* is truthy, which is not the same as "at least one file found".
    assert self._type_part2path, "Failed to find embedding files."

    # Both start empty; nothing in this method populates them.
    self._type_part2file_handle = {}
    self._type_part2matrix = {}
    self._use_cache = not disable_cache