def train_node2vec(paths, params):
    dump_process_pkl = paths.dump_process
    dump_context_dict = paths.dump_context_dict
    dump_context_list = paths.dump_context_list
    dump_walks = paths.dump_walks
    save_model_path = paths.node2vec_base
    embedding_txt = paths.embedding_text
    embedding_temp = paths.embedding_temp
    embedding = paths.embedding
    mesh_graph_file = paths.MeSH_graph_disease

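    # Fix all RNG seeds for reproducibility unless params.randomize is set.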
    if not params.randomize:
        np.random.seed(5)
        torch.manual_seed(5)
        random.seed(5)

    # ----------- Random walk --------------------
    directed_graph = False

    if not os.path.exists(dump_walks):
        num_walks = 30
        walk_length = 8
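        # node2vec biased second-order walks: p is the return parameter and q the
        # in-out parameter, trading off BFS-like vs. DFS-like exploration.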
        nx_G = read_graph(mesh_graph_file, directed_graph)
        G = Graph(nx_G, is_directed=directed_graph, p=params.p, q=params.q)
        G.preprocess_transition_probs()
        walks = G.simulate_walks(num_walks, walk_length)
        with open(dump_walks, 'wb') as f:
            pickle.dump(walks, f)
    else:
        with open(dump_walks, 'rb') as f:
            walks = pickle.load(f)

    # ---------- train SkipGram -----------------
    epochs = params.epochs
    batch_size = params.batch_size
    window = params.window
    num_neg_sample = params.num_neg_sample
    writer = SummaryWriter()

    if os.path.exists(dump_process_pkl):
        with open(dump_process_pkl, 'rb') as f:
            vocab = pickle.load(f)
    else:
        vocab = Vocabulary(lower=False)
        vocab.add_documents(walks)
        vocab.build()

        with open(dump_process_pkl, 'wb') as f:
            pickle.dump(vocab, f)

    # apply the vocab transformation only once, i.e. either when creating the context dict and list, or during training
    if not os.path.exists(dump_context_dict):
        l, d = multiprocess(walks, window=window, transform=vocab.doc2id)
        with open(dump_context_dict, 'wb') as f:
            pickle.dump(d, f)
        with open(dump_context_list, 'wb') as f:
            pickle.dump(l, f)
    else:
        with open(dump_context_dict, 'rb') as f:
            d = pickle.load(f)
        with open(dump_context_list, 'rb') as f:
            l = pickle.load(f)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # here the transformation is required since we sample node indices directly from the table
    sample_table = negative_sampling_table(vocab.token_counter(),
                                           transform=vocab.token_to_id)
    neg_sample = np.random.choice(sample_table, size=(len(l), num_neg_sample))
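    # neg_sample has shape (len(l), num_neg_sample): one row of pre-drawn negative
    # node indices per context-list entry; ContextData (defined elsewhere) pairs
    # these with the positive (center, context) examples for the DataLoader below.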

    context_data = ContextData(l, d, neg_sample, n_sample=5, transform=None)
    context_dataloader = DataLoader(context_data,
                                    batch_size=batch_size,
                                    shuffle=True,
                                    num_workers=6)

    model_embedding = SkipGram(len(vocab.vocab), embedding_size=1024)
    model_embedding.to(device)
    optimizer_embedding = torch.optim.SparseAdam(model_embedding.parameters(),
                                                 lr=0.005)

    train(model_embedding,
          optimizer_embedding,
          context_dataloader,
          epochs,
          device,
          neg_sample,
          n_sample=num_neg_sample,
          transform=None,
          writer=writer,
          save_path=save_model_path,
          l=l,
          d=d,
          vocab=vocab,
          batch_size=batch_size)
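    # Final node vectors: element-wise average of the input ("in") and output
    # ("out") embedding matrices learned by the skip-gram model.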
    word_embeddings = (model_embedding.out_embedding.weight.data +
                       model_embedding.in_embedding.weight.data) / 2
    word_embeddings = word_embeddings.cpu().numpy()

    sorted_vocab_tuple = sorted(vocab.vocab.items(), key=lambda kv: kv[1])

    with open(embedding_txt, 'w') as f:
        for idx, item in enumerate(sorted_vocab_tuple):
            if item[0] == '\n':
                continue
            f.write(item[0] + ' ' +
                    ' '.join([str(i) for i in word_embeddings[idx]]) + '\n')

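    # The plain-text file above is in GloVe format (no header line);
    # glove2word2vec adds the word2vec header so gensim's KeyedVectors can load it.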
    glove_file = datapath(embedding_txt)
    temp_file = get_tmpfile(embedding_temp)
    _ = glove2word2vec(glove_file, temp_file)

    wv = KeyedVectors.load_word2vec_format(temp_file)
    wv.save(embedding)

    writer.close()


# if __name__ == '__main__':
#     base_path = '/media/druv022/Data2/Final'
#     paths = Paths(base_path, node2vec_type='1')

#     train_node2vec(paths)
Example #2
def train_node2vec(paths, params):
    dump_process_pkl = paths.dump_process
    dump_context_dict = paths.dump_context_dict
    dump_context_list = paths.dump_context_list
    dump_walks = paths.dump_walks
    save_model_path = paths.node2vec_base
    embedding_txt = paths.embedding_text
    embedding_temp = paths.embedding_temp
    embedding = paths.embedding
    mesh_graph_file = paths.MeSH_graph_disease

    if not params.randomize:
        np.random.seed(5)
        torch.manual_seed(5)
        random.seed(5)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    writer = SummaryWriter()

    # ----------- Random walk --------------------
    directed_graph = False

    if not os.path.exists(dump_walks):
        num_walks = 30
        walk_length = 10
        nx_G = read_graph(mesh_graph_file, directed_graph)
        G = Graph(nx_G, is_directed=directed_graph, p=params.p, q=params.q)
        G.preprocess_transition_probs()
        walks = G.simulate_walks(num_walks, walk_length)
        with open(dump_walks, 'wb') as f:
            pickle.dump(walks, f)
    else:
        with open(dump_walks, 'rb') as f:
            walks = pickle.load(f)

    if os.path.exists(dump_process_pkl):
        with open(dump_process_pkl, 'rb') as f:
            vocab = pickle.load(f)
    else:
        vocab = Vocabulary(lower=False)
        vocab.add_documents(walks)
        vocab.build()

        with open(dump_process_pkl, 'wb') as f:
            pickle.dump(vocab, f)

    # ---------- build embedding model ----------
    mesh_file = paths.MeSH_file
    ELMO_folder = paths.elmo_folder
    options_file = paths.elmo_options
    weight_file = paths.elmo_weights

    elmo = Elmo(options_file, weight_file, 2, dropout=0)
    elmo.to(device)

    mesh_graph = nx.read_gpickle(mesh_graph_file)
    mesh_graph = mesh_graph.to_undirected()

    mesh_dict = read_mesh_file(mesh_file)

    # Get the list of nodes (idx 0 is '<pad>')
    node_list = list(vocab.vocab.keys())

    # create the weight matrix in node_list order (which corresponds to the original vocab index order)
    elmo_embedding_dim = 1024
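    # For each node with a MeSH entry, embed its scope note with ELMo, zero out
    # padded positions via the mask, and average over sentences and tokens to get
    # a single 1024-d vector; nodes without an entry get a zero vector. The stacked
    # matrix is handed to SkipGramModified below, presumably as initial weights.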
    if not os.path.exists(os.path.join(ELMO_folder, 'elmo_weights')):
        weight_list = []
        for idx, i in enumerate(node_list):
            if i in mesh_dict:
                node_idx = vocab.token_to_id(i)
                scope_note = mesh_dict[i].scope_note
                character_ids = batch_to_ids(scope_note).to(device)
                elmo_embeddings = elmo(character_ids)
                embeddings = elmo_embeddings['elmo_representations'][0]
                mask = elmo_embeddings['mask']
                embeddings = embeddings * mask.unsqueeze(2).expand(
                    mask.shape[0], mask.shape[1], embeddings.shape[2]).float()
                embeddings = embeddings.mean(dim=0).mean(dim=0)  # average
                weight_list.append(embeddings.cpu())
            else:
                weight_list.append(torch.zeros(elmo_embedding_dim))

        with open(os.path.join(ELMO_folder, 'elmo_weights'), 'wb') as f:
            pickle.dump(weight_list, f)
    else:
        with open(os.path.join(ELMO_folder, 'elmo_weights'), 'rb') as f:
            weight_list = pickle.load(f)

    weight = torch.stack(weight_list, dim=0)

    # ---------- train SkipGram -----------------
    epochs = params.epochs
    batch_size = params.batch_size
    window = params.window
    num_neg_sample = params.num_neg_sample

    # apply the vocab transformation only once, i.e. either when creating the context dict and list, or during training
    if not os.path.exists(dump_context_dict):
        l, d = multiprocess(walks, window=window, transform=vocab.doc2id)
        with open(dump_context_dict, 'wb') as f:
            pickle.dump(d, f)
        with open(dump_context_list, 'wb') as f:
            pickle.dump(l, f)
    else:
        with open(dump_context_dict, 'rb') as f:
            d = pickle.load(f)
        with open(dump_context_list, 'rb') as f:
            l = pickle.load(f)

    # here the transformation is required since we sample node indices directly from the table
    sample_table = negative_sampling_table(vocab.token_counter(),
                                           transform=vocab.token_to_id)
    neg_sample = np.random.choice(sample_table, size=(len(l), num_neg_sample))

    context_data = ContextData(l, d, neg_sample, n_sample=5, transform=None)
    context_dataloader = DataLoader(context_data,
                                    batch_size=batch_size,
                                    shuffle=True,
                                    pin_memory=True,
                                    num_workers=6)

    model_embedding = SkipGramModified(len(vocab.vocab),
                                       embedding_size=elmo_embedding_dim,
                                       weight=weight)
    model_embedding.to(device)
    optimizer_FC = torch.optim.Adam(model_embedding.parameters(), lr=0.005)

    train(model_embedding,
          optimizer_FC,
          context_dataloader,
          epochs,
          device,
          neg_sample,
          n_sample=num_neg_sample,
          writer=writer,
          save_path=save_model_path,
          l=l,
          d=d,
          vocab=vocab,
          batch_size=batch_size)

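    # Recover the learned node vectors with a single forward pass: x holds every
    # vocabulary index, while y and z are placeholder tensors for the context and
    # negative-sample slots; only the returned x is kept.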
    node_idx = []
    for item in node_list:
        node_idx.append(vocab.token_to_id(item))

    x = torch.tensor(node_idx, device=device)
    y = torch.zeros(x.shape, device=device)
    z = torch.zeros(x.shape, device=device)

    x, y, z = model_embedding(x, y, z)

    word_embeddings = x.cpu().detach().numpy()

    sorted_vocab_tuple = sorted(vocab.vocab.items(), key=lambda kv: kv[1])

    with open(embedding_txt, 'w') as f:
        for idx, item in enumerate(sorted_vocab_tuple):
            if item[0] == '\n':
                continue
            f.write(item[0] + ' ' +
                    ' '.join([str(i) for i in word_embeddings[idx]]) + '\n')

    glove_file = datapath(embedding_txt)
    temp_file = get_tmpfile(embedding_temp)
    _ = glove2word2vec(glove_file, temp_file)

    wv = KeyedVectors.load_word2vec_format(temp_file)
    wv.save(embedding)

    writer.close()
Example #3
def main():
    # Update path
    training_data = r'----------------/Data/Skipgram/hansards/training.en'
    dump_process_pkl = r'----------------/Data/Skipgram/hansards/processed_en_w.pkl'
    dump_context_dict = r'----------------/Data/Skipgram/hansards/context_dict_w.pkl'
    dump_context_list = r'----------------/Data/Skipgram/hansards/context_list_w.pkl'
    save_model_path = r'----------------/Data/Skipgram/hansards'
    embedding_txt = r'----------------/Data/Skipgram/hansards/embedding.txt'
    embedding_temp = r'----------------/Data/Skipgram/hansards/embedding_temp.txt'
    epochs = 20
    batch_size = 2**10
    window = 5
    num_neg_sample = 5
    writer = SummaryWriter()
    stopwords = set(stopwords.words('english'))

    with open(training_data, 'r') as f:
        data = f.readlines()
        data = [line.replace('\n', '').split(' ') for line in data]
        data = [[word for word in line if word not in stopwords]
                for line in data]

    if os.path.exists(dump_process_pkl):
        with open(dump_process_pkl, 'rb') as f:
            vocab = pickle.load(f)
    else:
        vocab = Vocabulary()
        vocab.add_documents(data)
        vocab.build()

        with open(dump_process_pkl, 'wb') as f:
            pickle.dump(vocab, f)

    # apply the vocab transformation only once, i.e. either when creating the context dict and list, or during training
    if not os.path.exists(dump_context_dict):
        l, d = multiprocess(data, window=window, transform=vocab.doc2id)
        with open(dump_context_dict, 'wb') as f:
            pickle.dump(d, f)
        with open(dump_context_list, 'wb') as f:
            pickle.dump(l, f)
    else:
        with open(dump_context_dict, 'rb') as f:
            d = pickle.load(f)
        with open(dump_context_list, 'rb') as f:
            l = pickle.load(f)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # here the transformation is required since we sample node indices directly from the table
    sample_table = negative_sampling_table(vocab.token_counter(),
                                           transform=vocab.token_to_id)
    neg_sample = np.random.choice(sample_table, size=(len(l), num_neg_sample))

    context_data = ContextData(l, d, neg_sample, n_sample=5, transform=None)
    context_dataloader = DataLoader(context_data,
                                    batch_size=batch_size,
                                    shuffle=True,
                                    num_workers=6)

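    # Rebuild the skip-gram model and resume from a previously saved checkpoint
    # ('sk_model5_5.pkl') before continuing training.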
    model_embedding = SkipGram(len(vocab.vocab), embedding_size=200)
    model_embedding.load_state_dict(
        torch.load(os.path.join(save_model_path, 'sk_model5_5.pkl')))
    model_embedding.to(device)
    optimizer_embedding = torch.optim.SparseAdam(model_embedding.parameters(),
                                                 lr=0.005)

    train(model_embedding,
          optimizer_embedding,
          context_dataloader,
          epochs,
          device,
          neg_sample,
          n_sample=num_neg_sample,
          save_path=save_model_path)
    word_embeddings = (model_embedding.out_embedding.weight.data +
                       model_embedding.in_embedding.weight.data) / 2
    word_embeddings = word_embeddings.cpu().numpy()

    sorted_vocab_tuple = sorted(vocab.vocab.items(), key=lambda kv: kv[1])

    with open(embedding_txt, 'w') as f:
        for idx, item in enumerate(sorted_vocab_tuple):
            if item[0] == '\n':
                continue
            f.write(item[0] + ' ' +
                    ' '.join([str(i) for i in word_embeddings[idx]]) + '\n')

    glove_file = datapath(embedding_txt)
    temp_file = get_tmpfile(embedding_temp)
    _ = glove2word2vec(glove_file, temp_file)

    wv = KeyedVectors.load_word2vec_format(temp_file)

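    # Quick sanity check: the classic analogy query king - man + woman, whose top
    # hit should be close to "queen" if the embeddings are reasonable.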
    result = wv.most_similar(positive=['woman', 'king'], negative=['man'])
    print("{}: {:.4f}".format(*result[0]))

    writer.close()