Ejemplo n.º 1
0
 def load_word2idx_and_embeddings(self, vocab_file, embedding_path=None, norm=True):
     """Load cached vocabulary maps and embeddings derived from a vocab/embedding pair.

     The pickle locations come from ``get_word2idx_pickle_paths``. Both pickles
     must already exist (the log message points at ``create_word2idx`` as their
     producer); if either is missing the process exits with a non-zero status.

     Args:
         vocab_file: vocabulary file, forwarded to ``get_word2idx_pickle_paths``.
         embedding_path: embedding source, forwarded to ``get_word2idx_pickle_paths``.
         norm: forwarded to ``get_word2idx_pickle_paths`` (affects which paths are chosen).

     Sets ``self.word2idx``, ``self.idx2word``, ``self.embeddings`` and ``self.oov_mask``.

     Raises:
         SystemExit: with status 1 if either pickle file is missing.
         AssertionError: if the vocabulary, embeddings and OOV mask differ in length.
     """
     word2idx_pkl_path, embeddings_pkl_path, _, _ = get_word2idx_pickle_paths(embedding_path=embedding_path,
                                                                              vocab_file=vocab_file,
                                                                              norm=norm)
     missing = [p for p in (word2idx_pkl_path, embeddings_pkl_path) if not os.path.exists(p)]
     if missing:
         # ERROR level so the reason is visible under the default logging config.
         for missing_path in missing:
             logging.error("%s not found", missing_path)
         logging.error("run create_word2idx first! exiting.")
         # A missing prerequisite is a failure: exit non-zero so callers/scripts notice.
         sys.exit(1)

     logging.info("loading word2idx from %s", word2idx_pkl_path)
     self.word2idx, self.idx2word = load(word2idx_pkl_path)
     logging.info("loading embeddings from %s", embeddings_pkl_path)
     self.embeddings, self.oov_mask = load(embeddings_pkl_path)
     logging.warning("vocab size %d", len(self.word2idx))
     assert len(self.word2idx) == len(self.embeddings) == len(self.oov_mask), (
         "inconsistent sizes: word2idx=%d embeddings=%d oov_mask=%d"
         % (len(self.word2idx), len(self.embeddings), len(self.oov_mask)))
Ejemplo n.º 2
0
 def load_test_cand_dict(self, path):
     """Populate ``self.test_cand_dict``, preferring the ``<path>.candict.pkl`` cache.

     On a cache miss the dictionary is parsed from ``path`` with
     ``read_candidates_dict`` and then written to the cache for next time.
     """
     cache_file = path + ".candict.pkl"
     if not os.path.exists(cache_file):
         logging.info("loading test canddict")
         self.test_cand_dict = read_candidates_dict(path)
         save(cache_file, self.test_cand_dict)
         return
     logging.info("pkl found! loading %s", cache_file)
     self.test_cand_dict = load(cache_file)
Ejemplo n.º 3
0
 def load_word2idx(self, word2idx_pkl_path=None):
     """Load the cached word <-> index maps into ``self.word2idx`` / ``self.idx2word``.

     Args:
         word2idx_pkl_path: path to a pickle whose loaded object unpacks as
             ``(word2idx, idx2word)``. Required: the former default ``True`` was
             not a usable path (``os.path.exists(True)`` checks file descriptor 1).

     Raises:
         TypeError: if the path is ``None`` or a bool.
         SystemExit: with status 1 if the pickle does not exist.
     """
     if word2idx_pkl_path is None or isinstance(word2idx_pkl_path, bool):
         raise TypeError("load_word2idx() requires word2idx_pkl_path, got %r" % (word2idx_pkl_path,))
     if not os.path.exists(word2idx_pkl_path):
         # ERROR level so the reason is visible under the default logging config.
         logging.error("%s not found", word2idx_pkl_path)
         logging.error("create word2idx first! exiting.")
         # Missing prerequisite is a failure: exit non-zero.
         sys.exit(1)
     logging.info("loading word2idx from %s", word2idx_pkl_path)
     word2idx, idx2word = load(word2idx_pkl_path)
     self.word2idx, self.idx2word = word2idx, idx2word
     logging.warning("vocab size %d", len(word2idx))
Ejemplo n.º 4
0
 def load_embeddings(self, embeddings_pkl_path):
     """Load the cached embeddings and OOV mask into ``self``.

     ``self.word2idx`` must already be populated: it is used for the vocab-size
     log line and the consistency check.

     Args:
         embeddings_pkl_path: path to a pickle whose loaded object unpacks as
             ``(embeddings, oov_mask)``.

     Raises:
         SystemExit: with status 1 if the pickle does not exist.
         AssertionError: if the vocabulary, embeddings and OOV mask differ in length.
     """
     if not os.path.exists(embeddings_pkl_path):
         # ERROR level so the reason is visible under the default logging config.
         logging.error("%s not found", embeddings_pkl_path)
         logging.error("run create_word2idx first! exiting.")
         # Missing prerequisite is a failure: exit non-zero.
         sys.exit(1)
     logging.info("loading embeddings from %s", embeddings_pkl_path)
     self.embeddings, self.oov_mask = load(embeddings_pkl_path)
     logging.warning("vocab size %d", len(self.word2idx))
     assert len(self.word2idx) == len(self.embeddings) == len(self.oov_mask), (
         "inconsistent sizes: word2idx=%d embeddings=%d oov_mask=%d"
         % (len(self.word2idx), len(self.embeddings), len(self.oov_mask)))
Ejemplo n.º 5
0
 def load_wid2desc(self, path=None):
     """Load wid descriptions (cached as ``<path>.pkl``) and map them through ``self.word2idx``.

     Args:
         path: source file for the descriptions; required despite the ``None``
             default (kept only for signature compatibility).

     Returns:
         The result of ``map_desc(wid2desc, w2i=self.word2idx)``, also stored in
         ``self.wid2desc``.

     Raises:
         TypeError: if ``path`` is not given.
     """
     if path is None:
         raise TypeError("load_wid2desc() requires a path")
     pkl_path = path + ".pkl"
     if os.path.exists(pkl_path):
         logging.info("pkl found! loading %s", pkl_path)
         wid2desc = load(pkl_path)
     else:
         logging.info("loading known wids descriptions")
         # Method names are not in scope inside a method body, so this bare call
         # resolves to the module-level ``load_wid2desc`` function, not this method.
         wid2desc = load_wid2desc(path)
         logging.info("saving pkl wid2desc")
         save(pkl_path, wid2desc)
     self.wid2desc = map_desc(wid2desc, w2i=self.word2idx)
     return self.wid2desc
Ejemplo n.º 6
0
 def load_wid2idx(self, kb_file):
     """Build (or load from the ``<kb_file>.wid2idx.pkl`` cache) the wid <-> index maps.

     Index 0 is reserved for ``NULL_TITLE_WID``; the wid on the i-th line
     (0-based) of ``kb_file`` gets index i + 1. Every line must split into exactly
     five tab-separated fields, the second being the wid; otherwise unpacking
     raises ``ValueError``.

     Returns:
         ``(wid2idx, idx2wid)``, also stored on ``self``.
     """
     # Example: kb_file = "data/enwiki/wid_title_mid_types_counts.txt"
     pkl_path = kb_file + ".wid2idx.pkl"
     if os.path.exists(pkl_path):
         logging.info("wid2idx pkl found! loading map %s", pkl_path)
         self.wid2idx, self.idx2wid = load(pkl_path)
     else:
         logging.info("building wid2idx from %s", kb_file)
         self.wid2idx, self.idx2wid = {NULL_TITLE_WID: 0}, {0: NULL_TITLE_WID}
         # Context manager closes the file even if a malformed line raises.
         with open(kb_file) as kb:
             for idx, line in enumerate(kb, start=1):
                 _, wid, _, _, _ = line.strip().split("\t")
                 self.wid2idx[wid] = idx
                 self.idx2wid[idx] = wid
         save(pkl_path, (self.wid2idx, self.idx2wid))
     return self.wid2idx, self.idx2wid
Ejemplo n.º 7
0
 def load_coh2idx(self, path):
     """Build (or load from the ``<path>.coh.pkl`` cache) the coherence-string vocabulary.

     ``path`` is read line by line; each line must split into exactly two
     tab-separated fields, of which only the first (the coherence string) is used.
     Index 0 is reserved for ``OOV_TOKEN``. Lines with the wrong field count and
     repeated strings are logged and skipped, so assigned indices stay contiguous.

     Sets ``self.coh2idx`` and ``self.idx2coh``.
     """
     pkl_path = path + ".coh.pkl"
     if os.path.exists(pkl_path):
         logging.info("pkl found! loading %s", pkl_path)
         self.coh2idx, self.idx2coh = load(pkl_path)
     else:
         logging.info("loading coh2idx")
         self.coh2idx, self.idx2coh = {OOV_TOKEN: 0}, {0: OOV_TOKEN}
         idx = 1
         # Context manager ensures the file handle is released.
         with open(path) as coh_file:
             for line in coh_file:
                 parts = line.strip().split("\t")
                 if len(parts) != 2:
                     logging.info("bad line %s", parts)
                     continue
                 cohstr, _ = parts  # second field is not used here
                 if cohstr in self.coh2idx:
                     logging.info("duplicate! %s", cohstr)
                     continue
                 self.coh2idx[cohstr] = idx
                 self.idx2coh[idx] = cohstr
                 idx += 1
         save(pkl_path, (self.coh2idx, self.idx2coh))
     logging.info("coh str vocab %d", len(self.coh2idx))