Beispiel #1
0
def test_keys():
    """Every key written to the database is reported by ``keys()``."""
    graph = {
        "A": ["B", "C"],
        "B": ["A"],
        "C": ["A"],
    }
    lookup = Sqlite3LookupTable(make_sqlite3_db("test_keys", graph))
    # Compare as sets so the check does not depend on the order keys come back in.
    assert set(lookup.keys()) == set(graph)
Beispiel #2
0
def test_sqlite3_lookup_getitem():
    """Indexing the table by key returns the list stored for that key."""
    stored = {
        "A": [1, 2, 3],
        "B": [4, 5, 6],
    }
    path = make_sqlite3_db("test_sqlite3_lookup_getitem", stored)
    lookup = Sqlite3LookupTable(path)
    for key, value in stored.items():
        assert lookup[key] == value
Beispiel #3
0
def test_len():
    """``len()`` of the table equals the number of rows written to it."""
    rows = {
        "A": ["B", "C"],
        "B": ["A"],
        "C": ["A"],
    }
    lookup = Sqlite3LookupTable(make_sqlite3_db("test_len", rows))
    assert len(lookup) == len(rows)
Beispiel #4
0
def test_sqlite3_lookup_contains():
    """Membership tests accept every stored key and reject an absent one."""
    stored = {
        "A": [1, 2, 3],
        "B": [4, 5, 6],
    }
    lookup = Sqlite3LookupTable(
        make_sqlite3_db("test_sqlite3_lookup_contains", stored))
    for key in stored:
        assert key in lookup
    # "C" was never written to the database.
    assert "C" not in lookup
Beispiel #5
0
def test_iter():
    """Iterating the table yields every stored ``(key, value)`` pair."""
    expected = {
        "A": ["B", "C"],
        "B": ["A"],
        "C": ["A"],
    }
    lookup = Sqlite3LookupTable(make_sqlite3_db("test_iter", expected))
    # Iterate explicitly so the test exercises __iter__, not keys()/__getitem__.
    collected = {}
    for key, value in lookup:
        collected[key] = value
    assert collected == expected
Beispiel #6
0
def test_len():
    """The dataset wrapper reports one item per row of the underlying table."""
    rows = {
        "A": "1",
        "B": "2",
        "C": "1",
        "D": "2",
    }
    table = Sqlite3LookupTable(make_sqlite3_db("test_len", rows))
    wrapped = sqlite3_dataset.Sqlite3Dataset(table)
    assert len(wrapped) == len(rows)
Beispiel #7
0
def test_sqlite3_is_preloaded():
    """``is_preloaded()`` is False on a fresh table and True after ``preload()``."""
    expected = {
        "A": [1, 2, 3],
        "B": [4, 5, 6],
    }
    # Use a database name unique to this test. It previously reused
    # "test_sqlite3_lookup_preload", which belongs to the test that unlinks its
    # database file, so both tests would target the same database if
    # make_sqlite3_db derives the file location from the name.
    db_path = make_sqlite3_db("test_sqlite3_is_preloaded", expected)
    table = Sqlite3LookupTable(db_path)
    assert not table.is_preloaded()
    # Should load the table contents to memory
    table.preload()
    assert table.is_preloaded()
Beispiel #8
0
def test_custom_value_column_name():
    """A database written with a non-default value column can be read back."""
    stored = {
        "A": [1, 2, 3],
        "B": [4, 5, 6],
    }
    column = "custom"
    path = make_sqlite3_db("test_custom_value_column_name",
                           stored,
                           value_column_name=column)
    lookup = Sqlite3LookupTable(path, value_column_name=column)
    for key, value in stored.items():
        assert lookup[key] == value
    assert "C" not in lookup
Beispiel #9
0
def test_iter_where():
    """``iterate(where=...)`` yields only the rows matching the SQL filter."""
    db_data = {
        "AA": ["B", "C"],
        "BBBB": ["A"],
        "CC": ["A"],
    }
    lookup = Sqlite3LookupTable(make_sqlite3_db("test_iter_where", db_data))
    filtered = {}
    for key, value in lookup.iterate(where="length(key) = 2"):
        filtered[key] = value
    # Only the two-character keys satisfy the predicate; "BBBB" is excluded.
    expected = {
        "AA": ["B", "C"],
        "CC": ["A"],
    }
    assert filtered == expected
Beispiel #10
0
def test_sqlite3_lookup_preload():
    """After ``preload()``, lookups still succeed once the database file is deleted."""
    stored = {
        "A": [1, 2, 3],
        "B": [4, 5, 6],
    }
    path = make_sqlite3_db("test_sqlite3_lookup_preload", stored)
    lookup = Sqlite3LookupTable(path)
    lookup.preload()  # load the table contents into memory
    # Remove the on-disk database to check that the content was actually loaded.
    # NOTE(review): on POSIX an already-open SQLite connection can still read
    # an unlinked file; confirm the table holds no open handle after preload().
    path.unlink()
    for key, value in stored.items():
        assert lookup[key] == value
    assert "C" not in lookup
Beispiel #11
0
def test_getitem():
    """Positional indexing over the dataset yields every ``(key, value)`` row."""
    expected = {
        "A": "1",
        "B": "2",
        "C": "1",
        "D": "2",
    }
    table = Sqlite3LookupTable(make_sqlite3_db("test_getitem", expected))
    dataset = sqlite3_dataset.Sqlite3Dataset(table)
    # Each dataset[i] is a (key, value) pair; collect them all by position.
    actual = dict(dataset[position] for position in range(len(dataset)))
    assert expected == actual
Beispiel #12
0
def test_sqlite3_lookup_pickle():
    """A table survives a pickle round trip and still serves lookups.

    The round trip happens in memory rather than through a fixed ``/tmp``
    file. This keeps the test portable, leaves no file behind, and avoids
    collisions between concurrent test runs.
    """
    expected = {
        "A": [1, 2, 3],
        "B": [4, 5, 6],
    }
    db_path = make_sqlite3_db("test_sqlite3_lookup_pickle", expected)
    table = Sqlite3LookupTable(db_path)
    payload = pickle.dumps(table)
    # Drop the original so the assertions below can only pass via the copy.
    del table
    table = pickle.loads(payload)
    for key, value in expected.items():
        assert table[key] == value
    assert "C" not in table
Beispiel #13
0
def test_subset():
    """A ``filter_fn`` restricts the dataset to the keys it accepts."""
    data = {
        "1": "A",
        "2": "B",
        "3": "C",
        "4": "D",
    }
    # Use a database name unique to this test. It previously reused
    # "test_getitem" (copied from that test), which writes different data, so
    # both tests would target the same database if make_sqlite3_db derives
    # the file location from the name.
    db_path = make_sqlite3_db("test_subset", data)
    table = Sqlite3LookupTable(db_path)

    def filter_fn(key):
        # Keep only keys "1" and "2".
        return int(key) <= 2

    dataset = sqlite3_dataset.Sqlite3Dataset(table, filter_fn)
    actual = {}
    for idx in range(len(dataset)):
        k, v = dataset[idx]
        actual[k] = v
    expected = {
        "1": "A",
        "2": "B",
    }
    assert expected == actual
Beispiel #14
0
def test_backward_compatable_fallback():
    """Table/column names that do not match the database fall back to defaults."""
    stored = {
        "A": [1, 2, 3],
        "B": [4, 5, 6],
    }
    path = make_sqlite3_db(
        "test_backward_compatable_fallback",
        stored,
    )
    # The database is created without custom names, so these overrides are
    # wrong for it; the table is expected to fall back to the defaults.
    wrong_names = dict(
        table_name="custom_table",
        key_column_name="custom_key",
        value_column_name="custom_value",
    )
    lookup = Sqlite3LookupTable(path, **wrong_names)
    for key, value in stored.items():
        assert lookup[key] == value
    assert "C" not in lookup
Beispiel #15
0
 def __init__(
     self,
     embedding_dir: Path,
     entity_db: Path,
     disable_cache: bool = False,
 ):
     """Open the entity lookup table and index the embedding files in a directory.

     Args:
         embedding_dir: Directory searched (non-recursively) for files matching
             ``embeddings_*.h5``. Wrapped in ``Path()``, so a ``str`` works too.
         entity_db: Entity database file, opened as a ``Sqlite3LookupTable``.
             Wrapped in ``Path()``, so a ``str`` works too.
         disable_cache: Forwarded to ``Sqlite3LookupTable``. Also stored
             inverted as ``self._use_cache``.

     Raises:
         AssertionError: If ``embedding_dir`` is not a directory, ``entity_db``
             is not a file, or no embedding file is indexed. These checks use
             ``assert`` and are skipped under ``python -O``.
     """
     embedding_dir = Path(embedding_dir)
     entity_db = Path(entity_db)
     assert embedding_dir.is_dir(), "Failed to find embedding_dir"
     assert entity_db.is_file(), "Failed to find entities"
     self.entities = Sqlite3LookupTable(entity_db,
                                        disable_cache=disable_cache)
     # Map each embedding file to the key parse_embedding_path derives from it
     # (presumably a (type, part) pair, per the attribute name -- confirm).
     # Files that parse to the same key overwrite each other; the last one
     # globbed wins.
     self._type_part2path = {
         parse_embedding_path(embedding_path): embedding_path
         for embedding_path in embedding_dir.glob("embeddings_*.h5")
     }
     # NOTE(review): any() tests the truthiness of the parsed keys, not just
     # that the dict is non-empty. If parse_embedding_path can return a falsy
     # key, this fires even though files exist; bool(...) / len(...) may be
     # the intent.
     assert any(self._type_part2path), "Failed to find embedding files."
     # Start empty here; presumably filled lazily elsewhere, keyed like
     # _type_part2path (not visible in this snippet).
     self._type_part2file_handle = {}
     self._type_part2matrix = {}
     self._use_cache = not disable_cache