def test_limit() -> None:
    """Check that ``limit`` caps the number of converted vectors.

    Limits of 2 and 1 must yield embeddings of exactly that length, while a
    limit of 0 must be rejected with ``ValueError``.
    """
    vec_path = "tests/test_data/word_vectors.vec"

    # Every limit gets its own temporary directory and database file.
    for limit in (2, 1):
        with TemporaryDirectory() as work_dir:
            sqlite_path = os.path.join(work_dir, "tmp.sqlite")
            loaded = _convert_from_text_and_load(vec_path, sqlite_path, limit=limit)
            with closing(loaded) as embeds:
                assert len(embeds) == limit

    # A zero limit is invalid and must fail during conversion.
    with TemporaryDirectory() as work_dir:
        sqlite_path = os.path.join(work_dir, "tmp.sqlite")
        with pytest.raises(ValueError):
            SqliteWordEmbedding.convert_text_format_to_db(
                vec_path, sqlite_path, limit=0
            )
def test_bad_input_path() -> None:
    """Converting from a text path that does not exist raises ``IOError``."""
    missing_input = "tests/test_data/nonexistent.txt"
    with TemporaryDirectory() as work_dir:
        sqlite_path = os.path.join(work_dir, "tmp.sqlite")
        with pytest.raises(IOError):
            SqliteWordEmbedding.convert_text_format_to_db(missing_input, sqlite_path)
def test_bad_float() -> None:
    """Converting the ``bad_float.badvec`` fixture raises ``ValueError``."""
    malformed_input = "tests/test_data/bad_float.badvec"
    with TemporaryDirectory() as work_dir:
        sqlite_path = os.path.join(work_dir, "tmp.sqlite")
        with pytest.raises(ValueError):
            SqliteWordEmbedding.convert_text_format_to_db(malformed_input, sqlite_path)
def test_load_bad_vocab_size() -> None:
    """Both bad-vocabulary-size fixtures must raise ``ValueError``.

    The first fixture goes through the convert-then-open helper; the second
    only runs the conversion step.
    """
    first_fixture = "tests/test_data/bad_vocab_size1.badvec"
    second_fixture = "tests/test_data/bad_vocab_size2.badvec"

    with TemporaryDirectory() as work_dir:
        sqlite_path = os.path.join(work_dir, "tmp.sqlite")
        with pytest.raises(ValueError):
            _convert_from_text_and_load(first_fixture, sqlite_path)

    with TemporaryDirectory() as work_dir:
        sqlite_path = os.path.join(work_dir, "tmp.sqlite")
        with pytest.raises(ValueError):
            SqliteWordEmbedding.convert_text_format_to_db(second_fixture, sqlite_path)
def test_load_basic() -> None:
    """Convert the plain and gzipped test vectors and check every accessor.

    For each input file the converted database must report the expected
    length, dimensionality, dtype and name, iterate in the original text
    file order, return the expected vector for every word, and raise
    ``KeyError`` for an unknown word.
    """
    # Expected contents are identical for both inputs, so build them once.
    # Test file is in float32
    expected_dtype = np.dtype(np.float32)
    expected_vocab = ["the", "Wikipedia", "article", "\t"]
    expected_vecs = {
        "the": np.array([0.0129, 0.0026, 0.0098], dtype=expected_dtype),
        "Wikipedia": np.array([0.0007, -0.0205, 0.0107], dtype=expected_dtype),
        "article": np.array([0.0050, -0.0114, 0.0150], dtype=expected_dtype),
        "\t": np.array([0.0001, 0.0002, 0.0003], dtype=expected_dtype),
    }

    for filename in ("word_vectors.vec", "word_vectors.vec.gz"):
        with TemporaryDirectory() as tmp_dir:
            db_path = os.path.join(tmp_dir, "tmp.sqlite")
            # Batch size intentionally not aligned with vocabulary size
            SqliteWordEmbedding.convert_text_format_to_db(
                _data_path(filename), db_path, batch_size=2
            )
            embed = SqliteWordEmbedding.open(db_path)
            # Close the embedding even when an assertion fails, so the
            # database file is released before the temporary directory is
            # cleaned up.
            try:
                assert len(embed) == len(expected_vocab)
                assert embed.dim == len(expected_vecs[expected_vocab[0]])
                # Iteration maintains the original text file order
                assert list(embed) == expected_vocab
                assert list(embed.keys()) == expected_vocab
                assert embed.dtype == expected_dtype
                assert embed.name == filename

                for word, vec in expected_vecs.items():
                    assert np.array_equal(embed[word], vec)

                for word, vec in embed.items():
                    assert np.array_equal(vec, expected_vecs[word])

                # Check order of items matches original order
                assert [item[0] for item in embed.items()] == expected_vocab

                for word in expected_vocab:
                    assert word in embed
                assert "jdlkjalkas" not in embed
                with pytest.raises(KeyError):
                    embed["asdaskl"]
            finally:
                embed.close()
Exemple #6
0
def show(action: str, db_path: Union[str, PathLike]) -> None:
    """Print one property of the embedding database at *db_path*.

    Args:
        action: Which property to print. ``DIM`` prints the dimensionality,
            ``LENGTH`` the number of entries, ``LENGTH_DIM`` both on one
            line, and ``VOCAB`` every word on its own line.
        db_path: Location of the database to open.

    Raises:
        ValueError: If *action* is not one of the recognized actions.
    """
    embed = SqliteWordEmbedding.open(db_path)
    # Always release the database handle, including when the action is
    # rejected or printing fails part-way through.
    try:
        if action == DIM:
            print(embed.dim)
        elif action == LENGTH:
            print(len(embed))
        elif action == LENGTH_DIM:
            print(len(embed), embed.dim)
        elif action == VOCAB:
            for word in embed:
                print(word)
        else:
            raise ValueError(f"Unknown action {action!r}")
    finally:
        embed.close()
def test_path_types() -> None:
    """Check that both ``str`` and ``Path`` arguments are accepted.

    Conversion is run with the input path given as each type, and the
    pre-built test database is opened with its path given as each type.
    The test only verifies that none of these calls raise.
    """
    vec_path_str = "tests/test_data/word_vectors.vec"
    for vec_path in (vec_path_str, Path(vec_path_str)):
        with TemporaryDirectory() as tmp_dir:
            db_path = os.path.join(tmp_dir, "tmp.sqlite")
            SqliteWordEmbedding.convert_text_format_to_db(vec_path, db_path)

    db_path_str = "tests/test_data/word_vectors.sqlite"
    for existing_db in (db_path_str, Path(db_path_str)):
        # Close immediately so the opened database handle is not leaked.
        SqliteWordEmbedding.open(existing_db).close()
def test_overwrite() -> None:
    """An existing database is only replaced when ``overwrite=True``.

    A second conversion onto the same path raises ``IOError`` by default and
    succeeds once overwriting is requested explicitly.
    """
    source = "tests/test_data/word_vectors.vec"
    convert = SqliteWordEmbedding.convert_text_format_to_db

    with TemporaryDirectory() as scratch:
        target = os.path.join(scratch, "tmp.sqlite")
        # First conversion creates the database.
        convert(source, target)
        # Without the flag, an existing database must not be clobbered.
        with pytest.raises(IOError):
            convert(source, target)
        # With the flag, the conversion goes through.
        convert(source, target, overwrite=True)
def test_gzipped() -> None:
    """Exercise explicit ``gzipped_input`` values during conversion.

    Only True/False need covering here, since the ``None`` default is already
    exercised by test_load_basic. The matching cases merely have to run
    without crashing.
    """
    matching_cases = (
        ("word_vectors.vec", False),
        ("word_vectors.vec.gz", True),
    )
    for name, is_gzipped in matching_cases:
        with TemporaryDirectory() as scratch:
            target = os.path.join(scratch, "tmp.sqlite")
            SqliteWordEmbedding.convert_text_format_to_db(
                _data_path(name), target, gzipped_input=is_gzipped
            )

    # Reading the gzip file as plain text fails because it is not valid
    # UTF-8, which surfaces as IOError.
    with TemporaryDirectory() as scratch:
        target = os.path.join(scratch, "tmp.sqlite")
        with pytest.raises(IOError):
            SqliteWordEmbedding.convert_text_format_to_db(
                _data_path("word_vectors.vec.gz"), target, gzipped_input=False
            )
def _convert_from_text_and_load(
    text_path, db_path, *args, **kwargs
) -> SqliteWordEmbedding:
    """Convert *text_path* into a database at *db_path* and open the result.

    Extra positional and keyword arguments are forwarded unchanged to
    ``SqliteWordEmbedding.convert_text_format_to_db``. The caller is
    responsible for closing the returned embedding.
    """
    convert = SqliteWordEmbedding.convert_text_format_to_db
    convert(text_path, db_path, *args, **kwargs)
    opened = SqliteWordEmbedding.open(db_path)
    return opened
def test_load_empty_db() -> None:
    """Opening the empty database fixture raises ``IOError``."""
    empty_db = "tests/test_data/empty_db.sqlite"
    # If the open unexpectedly succeeds, closing() still releases the handle
    # before pytest.raises reports the missing exception.
    with pytest.raises(IOError), closing(SqliteWordEmbedding.open(empty_db)):
        pass