def train_node2vec(paths, params):
    """Train node2vec (random walks + SkipGram) on the MeSH disease graph.

    Every expensive intermediate is cached on disk and reloaded when the cache
    file already exists:

    * random walks        -> ``paths.dump_walks``
    * walk vocabulary     -> ``paths.dump_process``
    * skip-gram contexts  -> ``paths.dump_context_dict`` / ``paths.dump_context_list``

    After training, the input and output embedding matrices are averaged,
    written as a GloVe-style text file (``paths.embedding_text``), converted to
    word2vec format (``paths.embedding_temp``) and saved as gensim
    ``KeyedVectors`` at ``paths.embedding``.

    Args:
        paths: object exposing the path attributes read at the top of this
            function.
        params: hyper-parameters; reads ``randomize``, ``p``, ``q``,
            ``epochs``, ``batch_size``, ``window`` and ``num_neg_sample``.
    """
    dump_process_pkl = paths.dump_process
    dump_context_dict = paths.dump_context_dict
    dump_context_list = paths.dump_context_list
    dump_walks = paths.dump_walks
    save_model_path = paths.node2vec_base
    embedding_txt = paths.embedding_text
    embedding_temp = paths.embedding_temp
    embedding = paths.embedding
    mesh_graph_file = paths.MeSH_graph_disease

    if not params.randomize:
        np.random.seed(5)
        torch.manual_seed(5)
        random.seed(5)

    # ----------- Random walk --------------------
    directed_graph = False
    if not os.path.exists(dump_walks):
        num_walks = 30
        walk_length = 8
        nx_G = read_graph(mesh_graph_file, directed_graph)
        G = Graph(nx_G, is_directed=directed_graph, p=params.p, q=params.q)
        G.preprocess_transition_probs()
        walks = G.simulate_walks(num_walks, walk_length)
        with open(dump_walks, 'wb') as f:
            pickle.dump(walks, f)
    else:
        # NOTE: the cache files are unpickled; only trust files written by this pipeline.
        with open(dump_walks, 'rb') as f:
            walks = pickle.load(f)

    # ---------- train SkipGram -----------------
    epochs = params.epochs
    batch_size = params.batch_size
    window = params.window
    num_neg_sample = params.num_neg_sample

    writer = SummaryWriter()
    try:
        if os.path.exists(dump_process_pkl):
            with open(dump_process_pkl, 'rb') as f:
                vocab = pickle.load(f)
        else:
            vocab = Vocabulary(lower=False)
            vocab.add_documents(walks)
            vocab.build()
            with open(dump_process_pkl, 'wb') as f:
                pickle.dump(vocab, f)

        # Apply the token->id transformation only once: either here, while
        # building the context dict/list, or during training, never both.
        if not os.path.exists(dump_context_dict):
            l, d = multiprocess(walks, window=window, transform=vocab.doc2id)
            with open(dump_context_dict, 'wb') as f:
                pickle.dump(d, f)
            with open(dump_context_list, 'wb') as f:
                pickle.dump(l, f)
        else:
            with open(dump_context_dict, 'rb') as f:
                d = pickle.load(f)
            with open(dump_context_list, 'rb') as f:
                l = pickle.load(f)

        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        # The sampling table holds token ids (transformation applied here), so
        # negative samples can be drawn directly as indices.
        sample_table = negative_sampling_table(vocab.token_counter(),
                                               transform=vocab.token_to_id)
        neg_sample = np.random.choice(sample_table, size=(len(l), num_neg_sample))

        # n_sample must match the number of negative-sample columns drawn above
        # (it was previously hard-coded to 5 regardless of num_neg_sample).
        context_data = ContextData(l, d, neg_sample, n_sample=num_neg_sample, transform=None)
        context_dataloader = DataLoader(context_data, batch_size=batch_size,
                                        shuffle=True, num_workers=6)

        model_embedding = SkipGram(len(vocab.vocab), embedding_size=1024)
        model_embedding.to(device)
        optimizer_embedding = torch.optim.SparseAdam(model_embedding.parameters(), lr=0.005)

        train(model_embedding, optimizer_embedding, context_dataloader, epochs, device,
              neg_sample, n_sample=num_neg_sample, transform=None, writer=writer,
              save_path=save_model_path, l=l, d=d, vocab=vocab, batch_size=batch_size)

        # Final node vectors: mean of the input and output embedding matrices.
        word_embeddings = (model_embedding.out_embedding.weight.data
                           + model_embedding.in_embedding.weight.data) / 2
        word_embeddings = word_embeddings.cpu().numpy()

        # Sort by id for a deterministic file order, and fetch each row by the
        # token's own id rather than by its position in the sorted list.
        sorted_vocab_tuple = sorted(vocab.vocab.items(), key=lambda kv: kv[1])
        # gensim decodes the GloVe file as UTF-8, so write it as UTF-8 explicitly.
        with open(embedding_txt, 'w', encoding='utf-8') as f:
            for token, token_id in sorted_vocab_tuple:
                if token == '\n':
                    # A newline token would break the one-vector-per-line format.
                    continue
                f.write(token + ' '
                        + ' '.join(str(v) for v in word_embeddings[token_id]) + '\n')

        glove_file = datapath(embedding_txt)
        temp_file = get_tmpfile(embedding_temp)
        _ = glove2word2vec(glove_file, temp_file)

        wv = KeyedVectors.load_word2vec_format(temp_file)
        wv.save(embedding)
    finally:
        writer.close()


# NOTE(review): stale entry point — train_node2vec now also requires `params`.
# if __name__ == '__main__':
#     base_path = '/media/druv022/Data2/Final'
#     paths = Paths(base_path, node2vec_type='1')
#     train_node2vec(paths)
def train_node2vec(paths, params):
    """Train an ELMo-initialised node2vec model on the MeSH disease graph.

    Steps (each expensive intermediate is cached on disk and reloaded when the
    cache file already exists):

    1. Simulate random walks over the graph (``paths.dump_walks``).
    2. Build a ``Vocabulary`` over the walk tokens (``paths.dump_process``).
    3. Build an initial weight matrix with one row per vocabulary id. Nodes
       found in the MeSH file get the masked mean of the ELMo token vectors of
       their scope note; all other rows are zeros
       (``<paths.elmo_folder>/elmo_weights``).
    4. Build skip-gram contexts (``paths.dump_context_dict`` / ``..._list``)
       and train ``SkipGramModified`` initialised with that weight matrix.
    5. Run every vocabulary id through the trained model, write the outputs in
       GloVe text format, convert them to word2vec format and save them as
       gensim ``KeyedVectors`` at ``paths.embedding``.

    Args:
        paths: object exposing the path attributes read in this function.
        params: hyper-parameters; reads ``randomize``, ``p``, ``q``,
            ``epochs``, ``batch_size``, ``window`` and ``num_neg_sample``.
    """
    dump_process_pkl = paths.dump_process
    dump_context_dict = paths.dump_context_dict
    dump_context_list = paths.dump_context_list
    dump_walks = paths.dump_walks
    save_model_path = paths.node2vec_base
    embedding_txt = paths.embedding_text
    embedding_temp = paths.embedding_temp
    embedding = paths.embedding
    mesh_graph_file = paths.MeSH_graph_disease

    if not params.randomize:
        np.random.seed(5)
        torch.manual_seed(5)
        random.seed(5)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Single writer for the whole run (a second, leaked writer used to be created later).
    writer = SummaryWriter()
    try:
        # ----------- Random walk --------------------
        directed_graph = False
        if not os.path.exists(dump_walks):
            num_walks = 30
            walk_length = 10
            nx_G = read_graph(mesh_graph_file, directed_graph)
            G = Graph(nx_G, is_directed=directed_graph, p=params.p, q=params.q)
            G.preprocess_transition_probs()
            walks = G.simulate_walks(num_walks, walk_length)
            with open(dump_walks, 'wb') as f:
                pickle.dump(walks, f)
        else:
            # NOTE: cache files are unpickled; only trust files written by this pipeline.
            with open(dump_walks, 'rb') as f:
                walks = pickle.load(f)

        if os.path.exists(dump_process_pkl):
            with open(dump_process_pkl, 'rb') as f:
                vocab = pickle.load(f)
        else:
            vocab = Vocabulary(lower=False)
            vocab.add_documents(walks)
            vocab.build()
            with open(dump_process_pkl, 'wb') as f:
                pickle.dump(vocab, f)

        # ---------- build embedding model ----------
        elmo_embedding_dim = 1024
        elmo_weights_path = os.path.join(paths.elmo_folder, 'elmo_weights')
        # Vocabulary tokens (idx 0 is '<pad>').
        node_list = list(vocab.vocab.keys())

        if not os.path.exists(elmo_weights_path):
            # ELMo and the MeSH file are only needed on a cache miss. Loading
            # them here avoids holding ELMo on the device during training.
            # NOTE: rows computed by older versions of this code averaged over
            # padding too; delete the cache file to regenerate it.
            elmo = Elmo(paths.elmo_options, paths.elmo_weights, 2, dropout=0)
            elmo.to(device)
            mesh_dict = read_mesh_file(paths.MeSH_file)

            # One row per vocabulary id. Each node's vector is stored at its own
            # id, so the matrix lines up with the embedding regardless of the
            # dict's iteration order.
            weight_list = [torch.zeros(elmo_embedding_dim) for _ in node_list]
            with torch.no_grad():  # inference only: do not build autograd graphs
                for token in node_list:
                    if token not in mesh_dict:
                        continue
                    # NOTE(review): assumes scope_note is what batch_to_ids
                    # expects (a list of tokenised sentences) — confirm.
                    scope_note = mesh_dict[token].scope_note
                    character_ids = batch_to_ids(scope_note).to(device)
                    elmo_out = elmo(character_ids)
                    embeddings = elmo_out['elmo_representations'][0]
                    mask = elmo_out['mask'].unsqueeze(2).float()
                    # Masked mean over all real tokens; padded positions
                    # neither add to the sum nor to the count.
                    summed = (embeddings * mask).sum(dim=(0, 1))
                    count = mask.sum().clamp(min=1.0)
                    weight_list[vocab.token_to_id(token)] = (summed / count).cpu()
            del elmo
            with open(elmo_weights_path, 'wb') as f:
                pickle.dump(weight_list, f)
        else:
            with open(elmo_weights_path, 'rb') as f:
                weight_list = pickle.load(f)
        weight = torch.stack(weight_list, dim=0)

        # ---------- train SkipGram -----------------
        epochs = params.epochs
        batch_size = params.batch_size
        window = params.window
        num_neg_sample = params.num_neg_sample

        # Apply the token->id transformation only once: either here, while
        # building the context dict/list, or during training, never both.
        if not os.path.exists(dump_context_dict):
            l, d = multiprocess(walks, window=window, transform=vocab.doc2id)
            with open(dump_context_dict, 'wb') as f:
                pickle.dump(d, f)
            with open(dump_context_list, 'wb') as f:
                pickle.dump(l, f)
        else:
            with open(dump_context_dict, 'rb') as f:
                d = pickle.load(f)
            with open(dump_context_list, 'rb') as f:
                l = pickle.load(f)

        # The sampling table holds token ids, so negatives are drawn directly as indices.
        sample_table = negative_sampling_table(vocab.token_counter(),
                                               transform=vocab.token_to_id)
        neg_sample = np.random.choice(sample_table, size=(len(l), num_neg_sample))
        # n_sample must match the number of negative-sample columns drawn above.
        context_data = ContextData(l, d, neg_sample, n_sample=num_neg_sample, transform=None)
        context_dataloader = DataLoader(context_data, batch_size=batch_size, shuffle=True,
                                        pin_memory=True, num_workers=6)

        model_embedding = SkipGramModified(len(vocab.vocab),
                                           embedding_size=elmo_embedding_dim, weight=weight)
        model_embedding.to(device)
        optimizer_FC = torch.optim.Adam(list(model_embedding.parameters()), lr=0.005)

        train(model_embedding, optimizer_FC, context_dataloader, epochs, device, neg_sample,
              n_sample=num_neg_sample, writer=writer, save_path=save_model_path,
              l=l, d=d, vocab=vocab, batch_size=batch_size)

        # ---------- export embeddings --------------
        # Query the model in id order so that output row k belongs to sorted_vocab_tuple[k].
        sorted_vocab_tuple = sorted(vocab.vocab.items(), key=lambda kv: kv[1])
        node_idx = [vocab.token_to_id(token) for token, _ in sorted_vocab_tuple]
        with torch.no_grad():
            x = torch.tensor(node_idx, device=device)
            y = torch.zeros(x.shape, device=device)
            z = torch.zeros(x.shape, device=device)
            x, y, z = model_embedding(x, y, z)
        word_embeddings = x.cpu().numpy()

        # gensim decodes the GloVe file as UTF-8, so write it as UTF-8 explicitly.
        with open(embedding_txt, 'w', encoding='utf-8') as f:
            for row, (token, _) in enumerate(sorted_vocab_tuple):
                if token == '\n':
                    # A newline token would break the one-vector-per-line format.
                    continue
                f.write(token + ' '
                        + ' '.join(str(v) for v in word_embeddings[row]) + '\n')

        glove_file = datapath(embedding_txt)
        temp_file = get_tmpfile(embedding_temp)
        _ = glove2word2vec(glove_file, temp_file)

        wv = KeyedVectors.load_word2vec_format(temp_file)
        wv.save(embedding)
    finally:
        writer.close()
def main():
    """Resume SkipGram training on the English Hansards corpus and sanity-check it.

    Stop words are removed from every line. The vocabulary and skip-gram
    contexts are built, or loaded from their pickle caches. A 200-d
    ``SkipGram`` model is initialised from the checkpoint ``sk_model5_5.pkl``
    in ``save_model_path`` and trained further. The averaged input/output
    embeddings are written in GloVe text format and converted to word2vec
    format, and the top result of the ``king - man + woman`` analogy is
    printed.
    """
    # Update path
    training_data = r'----------------/Data/Skipgram/hansards/training.en'
    dump_process_pkl = r'----------------/Data/Skipgram/hansards/processed_en_w.pkl'
    dump_context_dict = r'----------------/Data/Skipgram/hansards/context_dict_w.pkl'
    dump_context_list = r'----------------/Data/Skipgram/hansards/context_list_w.pkl'
    save_model_path = r'----------------/Data/Skipgram/hansards'
    embedding_txt = r'----------------/Data/Skipgram/hansards/embedding.txt'
    embedding_temp = r'----------------/Data/Skipgram/hansards/embedding_temp.txt'

    epochs = 20
    batch_size = 2**10
    window = 5
    num_neg_sample = 5

    writer = SummaryWriter()
    try:
        # Named ``stop_words`` so it does not shadow the module-level ``stopwords``
        # corpus. Assigning to ``stopwords`` anywhere in this function makes it a
        # local name, so ``stopwords.words(...)`` raised UnboundLocalError.
        stop_words = set(stopwords.words('english'))
        with open(training_data, 'r') as f:
            data = f.readlines()
        data = [line.replace('\n', '').split(' ') for line in data]
        data = [[word for word in line if word not in stop_words] for line in data]

        # NOTE: cache files are unpickled; only trust files written by this script.
        if os.path.exists(dump_process_pkl):
            with open(dump_process_pkl, 'rb') as f:
                vocab = pickle.load(f)
        else:
            vocab = Vocabulary()
            vocab.add_documents(data)
            vocab.build()
            with open(dump_process_pkl, 'wb') as f:
                pickle.dump(vocab, f)

        # Apply the token->id transformation only once: either here, while
        # building the context dict/list, or during training, never both.
        if not os.path.exists(dump_context_dict):
            l, d = multiprocess(data, window=window, transform=vocab.doc2id)
            with open(dump_context_dict, 'wb') as f:
                pickle.dump(d, f)
            with open(dump_context_list, 'wb') as f:
                pickle.dump(l, f)
        else:
            with open(dump_context_dict, 'rb') as f:
                d = pickle.load(f)
            with open(dump_context_list, 'rb') as f:
                l = pickle.load(f)

        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        # The sampling table holds token ids, so negatives are drawn directly as indices.
        sample_table = negative_sampling_table(vocab.token_counter(),
                                               transform=vocab.token_to_id)
        neg_sample = np.random.choice(sample_table, size=(len(l), num_neg_sample))
        # n_sample must match the number of negative-sample columns drawn above.
        context_data = ContextData(l, d, neg_sample, n_sample=num_neg_sample, transform=None)
        context_dataloader = DataLoader(context_data, batch_size=batch_size,
                                        shuffle=True, num_workers=6)

        model_embedding = SkipGram(len(vocab.vocab), embedding_size=200)
        # Load onto the CPU first: the model is moved to ``device`` right after,
        # and this also works when a GPU-saved checkpoint is loaded on a CPU-only host.
        model_embedding.load_state_dict(
            torch.load(os.path.join(save_model_path, 'sk_model5_5.pkl'), map_location='cpu'))
        model_embedding.to(device)
        optimizer_embedding = torch.optim.SparseAdam(model_embedding.parameters(), lr=0.005)

        train(model_embedding, optimizer_embedding, context_dataloader, epochs, device,
              neg_sample, n_sample=num_neg_sample, writer=writer, save_path=save_model_path)

        # Final word vectors: mean of the input and output embedding matrices.
        word_embeddings = (model_embedding.out_embedding.weight.data
                           + model_embedding.in_embedding.weight.data) / 2
        word_embeddings = word_embeddings.cpu().numpy()

        # Sort by id for a deterministic file order; fetch each row by the token's id.
        sorted_vocab_tuple = sorted(vocab.vocab.items(), key=lambda kv: kv[1])
        with open(embedding_txt, 'w', encoding='utf-8') as f:
            for token, token_id in sorted_vocab_tuple:
                if token == '\n':
                    # A newline token would break the one-vector-per-line format.
                    continue
                f.write(token + ' '
                        + ' '.join(str(v) for v in word_embeddings[token_id]) + '\n')

        glove_file = datapath(embedding_txt)
        temp_file = get_tmpfile(embedding_temp)
        _ = glove2word2vec(glove_file, temp_file)

        wv = KeyedVectors.load_word2vec_format(temp_file)
        # Analogy sanity check: king - man + woman ≈ ?
        result = wv.most_similar(positive=['woman', 'king'], negative=['man'])
        print("{}: {:.4f}".format(*result[0]))
    finally:
        writer.close()