コード例 #1
0
def main():
    """Load trained GloVe weights and print a few word-similarity probes.

    Reads the pickled vocabulary from ``config.cooccurrence_dir/vocab.pkl`` and
    the model state dict from ``config.output_filepath``. It builds a gensim
    ``KeyedVectors`` whose vectors are the element-wise sum of the model's
    ``weight`` and ``weight_tilde`` embedding tables. It then prints pairwise
    similarities and nearest neighbours for a few fixed probe words.
    """
    config = load_config()
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file; only point cooccurrence_dir at trusted, locally produced artifacts.
    with open(os.path.join(config.cooccurrence_dir, "vocab.pkl"), "rb") as f:
        vocab = pickle.load(f)

    model = GloVe(
        vocab_size=config.vocab_size,
        embedding_size=config.embedding_size,
        x_max=config.x_max,
        alpha=config.alpha,
    )
    # map_location="cpu" lets a checkpoint that was saved from GPU tensors load
    # on a CPU-only machine. The .numpy() call below requires CPU tensors anyway.
    model.load_state_dict(torch.load(config.output_filepath, map_location="cpu"))

    keys = [vocab.get_token(index) for index in range(config.vocab_size)]
    # Final word vectors = sum of the two embedding tables; detach() drops the
    # autograd graph so the result can be converted to a NumPy array.
    weights = (
        model.weight.weight.detach() + model.weight_tilde.weight.detach()
    ).numpy()

    keyed_vectors = KeyedVectors(vector_size=config.embedding_size)
    keyed_vectors.add_vectors(keys=keys, weights=weights)

    # (label printed, first word, second word) for each similarity probe.
    similarity_probes = [
        ("How similar is man and woman:", "woman", "man"),
        ("How similar is man and apple:", "apple", "man"),
        ("How similar is woman and apple:", "apple", "woman"),
    ]
    for label, first, second in similarity_probes:
        print(label)
        print(keyed_vectors.similarity(first, second))

    for query in ["computer", "united", "early"]:
        print(f"Most similar words of {query}:")
        # similar_by_word yields 2-tuples; keep only the first element (the
        # neighbouring word). A distinct name avoids shadowing `query`.
        neighbours = [
            neighbour for neighbour, _ in keyed_vectors.similar_by_word(query)
        ]
        print(neighbours)
コード例 #2
0
ファイル: test_keyedvectors.py プロジェクト: EricM2/venv
    def test_add_type(self):
        """Adding float64 vectors leaves the stored matrix in the REAL dtype."""
        keyed = KeyedVectors(2)
        assert keyed.vectors.dtype == REAL

        new_keys = ["a"]
        # A single 2-dimensional row of ones, deliberately float64.
        new_rows = np.ones((1, 2), dtype=np.float64)
        keyed.add_vectors(new_keys, new_rows)

        assert keyed.vectors.dtype == REAL
コード例 #3
0
ファイル: test_keyedvectors.py プロジェクト: EricM2/venv
 def test_no_header(self):
     """Save without a header, reload with no_header=True, and compare."""
     source_kv = KeyedVectors(vector_size=100)
     num_keys = 20
     key_list = list(map(str, range(num_keys)))
     weight_list = [
         pseudorandom_weak_vector(source_kv.vector_size)
         for _ in range(num_keys)
     ]
     source_kv.add_vectors(key_list, weight_list)

     txt_path = gensim.test.utils.get_tmpfile("tmp_kv.txt")
     source_kv.save_word2vec_format(txt_path, binary=False, write_header=False)
     reloaded_kv = KeyedVectors.load_word2vec_format(
         txt_path, binary=False, no_header=True,
     )

     # Keys must come back in the same order and every vector must match exactly.
     self.assertEqual(source_kv.index_to_key, reloaded_kv.index_to_key)
     vectors_identical = (source_kv.vectors == reloaded_kv.vectors).all()
     self.assertTrue(vectors_identical)
コード例 #4
0
ファイル: test_keyedvectors.py プロジェクト: EricM2/venv
    def test_add_single(self):
        """Adding entities one at a time works on filled and empty KeyedVectors."""
        n_entities = 5
        new_keys = [
            f'___some_entity{i}_not_present_in_keyed_vectors___'
            for i in range(n_entities)
        ]
        new_vectors = [
            np.random.randn(self.vectors.vector_size) for _ in range(n_entities)
        ]
        pairs = list(zip(new_keys, new_vectors))

        # Case 1: the fixture already holds vectors; add one key per call.
        for key, vec in pairs:
            self.vectors.add_vectors(key, vec)
        for key, vec in pairs:
            self.assertTrue(np.allclose(self.vectors[key], vec))

        # Case 2: start from an empty KeyedVectors of the same dimensionality.
        fresh_kv = KeyedVectors(self.vectors.vector_size)
        for key, vec in pairs:
            fresh_kv.add_vectors(key, vec)
        for key, vec in pairs:
            self.assertTrue(np.allclose(fresh_kv[key], vec))