Example No. 1
def test_load_without_vocab(self):
    words = ['the', 'of', 'in', 'a', 'to', 'and']
    glove = "test/data_for_tests/embedding/small_static_embedding/glove.6B.50d_test.txt"
    word2vec = "test/data_for_tests/embedding/small_static_embedding/word2vec_test.txt"
    g_m, vocab = EmbedLoader.load_without_vocab(glove)
    self.assertEqual(g_m.shape, (8, 50))
    for word in words:
        self.assertIn(word, vocab)
    w_m, vocab = EmbedLoader.load_without_vocab(word2vec, normalize=True)
    self.assertEqual(w_m.shape, (8, 50))
    self.assertAlmostEqual(np.linalg.norm(w_m, axis=1).sum(),
                           8,
                           delta=1e-4)
    for word in words:
        self.assertIn(word, vocab)
    # no unk
    w_m, vocab = EmbedLoader.load_without_vocab(word2vec,
                                                normalize=True,
                                                unknown=None)
    self.assertEqual(w_m.shape, (7, 50))
    self.assertAlmostEqual(np.linalg.norm(w_m, axis=1).sum(),
                           7,
                           delta=1e-4)
    for word in words:
        self.assertIn(word, vocab)
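Note: a minimal standalone sketch of the API exercised above, assuming fastNLP's EmbedLoader is importable from fastNLP.io and that the GloVe-format fixture path from the test exists:

import numpy as np
from fastNLP.io import EmbedLoader

# Path taken from the test above; substitute any GloVe-format text file.
embed_path = "test/data_for_tests/embedding/small_static_embedding/glove.6B.50d_test.txt"

# Returns the embedding matrix plus the Vocabulary built from the file
# (padding/unknown entries are included unless they are disabled).
matrix, vocab = EmbedLoader.load_without_vocab(embed_path, normalize=True)
print(matrix.shape)                    # (len(vocab), embedding_dim)
print(np.linalg.norm(matrix, axis=1))  # rows are roughly unit-norm with normalize=True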
Example No. 2
def test_load_with_vocab(self):
    vocab = Vocabulary()
    glove = "test/data_for_tests/glove.6B.50d_test.txt"
    word2vec = "test/data_for_tests/word2vec_test.txt"
    vocab.add_word('the')
    vocab.add_word('none')
    g_m = EmbedLoader.load_with_vocab(glove, vocab)
    self.assertEqual(g_m.shape, (4, 50))
    w_m = EmbedLoader.load_with_vocab(word2vec, vocab, normalize=True)
    self.assertEqual(w_m.shape, (4, 50))
    self.assertAlmostEqual(np.linalg.norm(w_m, axis=1).sum(), 4)
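Note: the matrix returned by load_with_vocab lines up row-by-row with the given Vocabulary, so it can be dropped straight into an embedding layer. A hedged sketch, assuming PyTorch is available and using the same fixture path as the test above:

import torch
from fastNLP import Vocabulary
from fastNLP.io import EmbedLoader

vocab = Vocabulary()
vocab.add_word('the')
vocab.add_word('none')

# Same GloVe-format fixture as in the test above.
matrix = EmbedLoader.load_with_vocab("test/data_for_tests/glove.6B.50d_test.txt", vocab)

# Wrap the numpy matrix in a torch embedding layer for downstream models.
embed = torch.nn.Embedding.from_pretrained(torch.tensor(matrix, dtype=torch.float))
print(embed.weight.shape)  # (len(vocab), 50)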
Example No. 3
def process_data_1(embed_file, cws_train):
    embed, vocab = EmbedLoader.load_without_vocab(embed_file)
    time.sleep(1)  # used to check whether a later call gets its result from the cache
    with open(cws_train, 'r', encoding='utf-8') as f:
        d = DataSet()
        for line in f:
            line = line.strip()
            if len(line) > 0:
                d.append(Instance(raw=line))
    return embed, vocab, d
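Note: the time.sleep(1) hints that this function is meant to be wrapped by a result cache; fastNLP ships a cache_results decorator for this purpose. A hedged sketch, assuming cache_results is importable from fastNLP (fastNLP.core.utils in older releases) and using a hypothetical cache file name:

from fastNLP import cache_results

# Hypothetical wrapper: the first call runs the expensive loading and pickles the
# result to 'process_data_1.pkl'; later calls load that cache instead of re-reading.
@cache_results('process_data_1.pkl')
def cached_process_data_1(embed_file, cws_train):
    return process_data_1(embed_file, cws_train)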
Example No. 4
def pretrain_embedding(args):
    if not os.path.exists(args.prepare_dir):
        os.makedirs(args.prepare_dir)
    text_data = TextData()
    with open(os.path.join(args.vocab_dir, args.vocab_data), 'rb') as fin:
        text_data = pickle.load(fin)

    if args.pretrain_model == 'word2vec':
        w2v_model_path = word2vec_pretrain(args, text_data.train_set['text'])
        print("Load the glove model from {0}.".format(w2v_model_path))
        loader = EmbedLoader()
        vocab = text_data.vocab
        pretrained = loader.load_with_vocab(w2v_model_path, vocab)
        print(pretrained)
        save_path = os.path.join(args.prepare_dir, 'w2v_embeds.pkl')
        with open(save_path, 'wb') as fout:
            pickle.dump(pretrained, fout)
        print("Building word2vec done.Matrix saved in {0}".format(save_path))

    elif args.pretrain_model == 'glove':
        print("Using Glove...")
        glove_model_path = os.path.join(args.prepare_dir, 'glove_model.txt')
        print("Load the glove model from {0}.".format(glove_model_path))
        loader = EmbedLoader()
        vocab = text_data.vocab
        pretrained = loader.load_with_vocab(glove_model_path, vocab)
        print(pretrained)
        save_path = os.path.join(args.prepare_dir, 'glove_embeds.pkl')
        with open(save_path, 'wb') as fout:
            pickle.dump(pretrained, fout)
        print("Building Glove done.Matrix saved in {0}".format(save_path))
    elif args.pretrain_model == 'glove2wv':
        print("Using Glove trained with word2vec...")
        glove_model_path = glove2wv_pretrain(args, text_data.train_set['text'])
        print("Load the glove model from {0}.".format(glove_model_path))
        loader = EmbedLoader()
        vocab = text_data.vocab
        pretrained = loader.load_with_vocab(glove_model_path, vocab)
        print(pretrained)
        save_path = os.path.join(args.prepare_dir, 'glove2wv_embeds.pkl')
        with open(save_path, 'wb') as fout:
            pickle.dump(pretrained, fout)
        print("Building Glove done.Matrix saved in {0}".format(save_path))

    else:
        print("No pretrain model will be used.")
Example No. 5
def embedding_load_with_cache(emb_file, cache_dir, vocab, **kwargs):
    def match_cache(file, cache_dir):
        md5 = md5_for_file(file)
        cache_files = os.listdir(cache_dir)
        for fn in cache_files:
            if md5 in fn.split("-")[-1]:
                return os.path.join(cache_dir, fn), True
        return (
            "{}-{}.pkl".format(os.path.join(cache_dir, os.path.basename(file)),
                               md5),
            False,
        )

    def get_cache(file):
        if not os.path.exists(file):
            return None
        with open(file, "rb") as f:
            emb = pickle.load(f)
        return emb

    os.makedirs(cache_dir, exist_ok=True)
    cache_fn, match = match_cache(emb_file, cache_dir)
    if not match:
        print("cache missed, re-generating cache at {}".format(cache_fn))
        emb, ori_vocab = EmbedLoader.load_without_vocab(emb_file,
                                                        padding=None,
                                                        unknown=None,
                                                        normalize=False)
        with open(cache_fn, "wb") as f:
            pickle.dump((emb, ori_vocab), f)

    else:
        print("cache matched at {}".format(cache_fn))

    # use cache
    print("loading embeddings ...")
    emb = get_cache(cache_fn)
    assert emb is not None
    return embedding_match_vocab(vocab, emb[0], emb[1], **kwargs)
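Note: this example relies on two helpers not shown here: md5_for_file (hashes the embedding file so the cache name tracks its contents) and embedding_match_vocab (re-maps the cached matrix and vocabulary onto the target vocab). A hedged sketch of what md5_for_file might look like, offered only as an illustration:

import hashlib

def md5_for_file(path, chunk_size=1 << 20):
    # Hash the file in chunks so large embedding files are not read into memory at once.
    md5 = hashlib.md5()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            md5.update(chunk)
    return md5.hexdigest()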