def make_word_embedding_matrix_file(vocab, dataset_name='9sr'):
    """Makes a word embedding matrix file, ready to be loaded as the weights
    of an embedding layer in a NN.

    Reads the vocabulary of a dataset as returned by get_vocab (this must
    exist). Saves the word embedding matrix in the word_embeddings_dir with an
    appropriate name.

    Args:
      vocab: vocabulary for that dataset. word -> id
      dataset_name: String that defines the dataset name and location of vocab.
    """
    np.random.seed(12345)  # random but reproducible
    glove_emb = load_word_embeddings_from_file(vocab=vocab)
    print_intersection(vocab, glove_emb)
    # word_emb_array = np.random.random((len(vocab) + 1, 300))  # +1 is only for my code
    word_emb_array = np.random.random((len(vocab), 300))
    for w, i in vocab.items():
        if glove_emb.get(w) is not None:
            word_emb_array[i] = np.array(glove_emb.get(w))
    filename = os.path.join(word_embeddings_dir, 'query_completion',
                            '%s_glove_matrix_300d.pkl' % dataset_name)
    save_pickle(filename, word_emb_array)
    return word_emb_array, glove_emb
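# Illustrative sketch (not part of the original pipeline): the pickled matrix
# can be read back with the repo's `load_pickle` helper and indexed with the
# ids in `vocab` to recover a word's 300-d vector, or handed to an embedding
# layer as its initial weights. Path and dataset name mirror the defaults above;
# the helper below is hypothetical and only shows the intended usage.
def _example_lookup_embedding(word, vocab, dataset_name='9sr'):
    emb_path = os.path.join(word_embeddings_dir, 'query_completion',
                            '%s_glove_matrix_300d.pkl' % dataset_name)
    word_emb_array = load_pickle(emb_path, False)  # shape: (len(vocab), 300)
    # '<unk>' is guaranteed to be in the vocabulary by json2vocab below.
    return word_emb_array[vocab.get(word, vocab['<unk>'])]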
def json2vocab(filenames, vocab_filename, vocab_size, valid_users=None,
               valid_subreddits=None, overwrite=False):
    """Reads all the .json files and keeps the top words mentioned in them by
    the valid users and subreddits.

    Args:
      filenames: list of paths where the .json files are.
      vocab_filename: String with the path of the vocabulary.
      vocab_size: Total number of words to be used, i.e. the top-k limit.
      valid_users: Set of users whose words should be kept.
      valid_subreddits: Set of subreddits whose words should be kept.
      overwrite: Whether to overwrite an existing file.

    Returns:
      A dict from word -> id (ids start from 1).

    Saves:
      The vocabulary dict and a text file with word, count pairs (for sanity check).
    """
    counter_filename = vocab_filename.replace('pkl', 'txt')
    if os.path.exists(vocab_filename) and not overwrite:
        return load_pickle(vocab_filename, False)
    print 'Making:\n%s\n%s' % (vocab_filename, counter_filename)
    counters = []
    limit = vocab_size
    pool = mp.Pool(n_proc)
    proc_data_size = int(np.ceil(1. * len(filenames) / n_proc))
    for i in range(n_proc):
        proc_filenames = filenames[i * proc_data_size:(i + 1) * proc_data_size]
        if len(proc_filenames) > 0:
            pool.apply_async(_json2vocab_mp,
                             args=(i, proc_filenames, valid_subreddits, valid_users),
                             callback=counters.append)
    pool.close()
    pool.join()
    combined_counters = combine_dicts(counters)
    print 'Total words before pruning were %d' % len(combined_counters)
    sorted_vocab = sorted(combined_counters.items(), key=lambda x: x[1], reverse=True)
    vocab = set([x[0] for x in sorted_vocab[:limit - 3]])
    # explicitly add unk and sentence start and end tokens.
    vocab.add('<unk>')
    vocab.add('<sent_end>')
    vocab.add('<sent_start>')
    vocab = set_to_dict(vocab, 1)  # word ids start from 1!!
    save_pickle(vocab_filename, vocab)
    final_counter = np.array(sorted_vocab[:limit])
    save_txt(counter_filename, final_counter, delimiter=' ', fmt='%s')
    return vocab
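# Minimal sketch (an assumption, not the repo's `_json2vocab_mp`) of what each
# worker dispatched above is expected to return: a word -> count dict for its
# shard of .json files, restricted to the valid users/subreddits. Tokenization
# here is a plain whitespace split purely for illustration.
def _example_json2vocab_worker(proc_id, filenames, valid_subreddits, valid_users):
    counts = {}
    for fname in filenames:
        with open(fname) as f:
            for line in f:
                entry = json.loads(line)
                if valid_users is not None and entry['author'] not in valid_users:
                    continue
                if valid_subreddits is not None and entry['subreddit'] not in valid_subreddits:
                    continue
                for w in entry['body'].lower().split():
                    counts[w] = counts.get(w, 0) + 1
    return counts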
def create_valid_subreddit_set(subscribers_dict, subscriber_limit=1000, overwrite=False):
    """Makes a set of subreddits with at least subscriber_limit subscribers,
    based on an input dictionary."""
    subreddit_set_filename = get_valid_sub_name(subscriber_limit)
    if os.path.exists(subreddit_set_filename) and not overwrite:
        return load_pickle(subreddit_set_filename, False)
    sub_set = set()  # get it?
    for (subreddit, subscriber_count) in subscribers_dict.iteritems():
        if subscriber_count >= subscriber_limit:
            sub_set.add(subreddit)
    print '-->Sub set has %d subreddits' % len(sub_set)
    save_pickle(subreddit_set_filename, sub_set)
    return sub_set
def create_valid_user_set(params, years=None, overwrite=False):
    """Creates a set of valid users, meaning users with more than min_posts posts."""
    filename = get_valid_user_filename(params, years)
    if os.path.exists(filename) and not overwrite:
        return load_pickle(filename, False)
    user_counts = combine_dicts(get_user_count_dictionaries(params, years))
    usernames = get_top_users(params.min_posts, user_counts)
    usernames = remove_bots(usernames, params)
    print '--> Total valid users: %d' % len(usernames)
    save_pickle(filename, usernames)
    return usernames
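# Hedged sketch of what `combine_dicts` is expected to do in this module (the
# repo's own helper may differ in details): merge a list of count dictionaries
# by summing the values per key. Used above to merge per-process user counts
# and, in json2vocab, per-process word counts.
def _example_combine_dicts(dicts):
    combined = {}
    for d in dicts:
        for k, v in d.items():
            combined[k] = combined.get(k, 0) + v
    return combined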
def lm_valid_users(filename, params, uxs=None, user_names=None, years=None, overwrite=False):
    """Creates the valid users for language modeling, after applying all
    previous filters + the min h_index filter.

    Args:
      filename: language model valid users filename to be saved.
      params: Parameters of the preprocessing run.
      uxs: User by Subreddit count matrix (sparse).
      user_names: dictionary from user id to user_name.
      years: list of all the years we want to take into consideration. If None,
        it selects all available.
      overwrite: Boolean to define whether to overwrite an existing file.

    Returns:
      A set of user names (all of whom have an h_index at least as large as
      that specified in params).
    """
    if os.path.exists(filename) and not overwrite:
        return load_pickle(filename, False)
    if user_names is None:
        user_set = create_valid_user_set(params, years, overwrite)
        user_names = invert_dict(set_to_dict(user_set))
    if uxs is None:
        uxs = data_to_sparse(dict2matrix(params, valid_users=user_set, years=years))
    assert len(user_names) == uxs.shape[0]
    print 'Calculating h-indices...'
    user_h_index = []
    for u in range(uxs.shape[0]):
        counts = sorted(uxs.getrow(u).data, reverse=True)
        user_h_index.append(get_h_index(counts))
    user_h_index = np.array(user_h_index)
    top_users = np.where(user_h_index >= params.h_index_min)[0]
    top_usernames = set([user_names[u] for u in top_users])
    print 'Total Users: %d -> after pruning with at least %d h-index, %d users left' % (
        len(user_names), params.h_index_min, len(top_usernames))
    save_pickle(filename, top_usernames)
    return top_usernames
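# Hedged sketch of the h-index notion used above (the repo's own `get_h_index`
# may differ in implementation): for a user's per-subreddit post counts sorted
# in descending order, the h-index is the largest h such that the user posted
# at least h times in each of at least h subreddits.
def _example_h_index(sorted_counts):
    h = 0
    for rank, count in enumerate(sorted_counts, start=1):
        if count >= rank:
            h = rank
        else:
            break
    return h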
def crawl_subreddit_subscribers(subreddit_limit, subscribers_dict, subscribers_dict_filename):
    """Crawls a dictionary from subreddit -> number of subscribers from
    http://redditmetrics.com/top/"""
    from lxml import html
    for page_offset in range(0, subreddit_limit, 100):
        url_str = 'http://redditmetrics.com/top/offset/%d' % page_offset
        print url_str, len(subscribers_dict)
        time.sleep(1 + 3 * random.random())
        page = requests.get(url_str)
        tree = html.fromstring(page.content)
        table = tree.xpath('//td[@class="tod"]')
        for i in range(100):
            # every second entry is the name, every third is the number of subscribers
            subscribers_dict[table[3 * i + 1].iterlinks().next()[2][3:]] = int(
                table[3 * i + 2].text.replace(',', ''))
        save_pickle(subscribers_dict_filename, subscribers_dict)
    return subscribers_dict
def _json2dicts_mp(proc_id, filename, user_count_filename, uc_dict_filename,
                   params, overwrite=False):
    """Processes a file as downloaded from the reddit data webpage:
    http://files.pushshift.io/reddit/comments/

    For each file it creates two files: one with a dictionary from
    user -> count and one from user_category -> count.
    """
    print '\t%d Converting %s for at least %d subscribers' % (
        proc_id, os.path.basename(filename), params.min_subscribers),
    if os.path.exists(uc_dict_filename) and not overwrite:
        print ': exists! Moving on'
        return
    print
    count_dict = {}
    user_post_count = {}
    f = open(filename, 'r')
    i = 0
    limit = 1
    start_time = time.time()
    for line in f:
        i += 1
        if i % limit == 0:
            time_passed = time.time() - start_time
            print '\t%d %d posts, unique user-subreddit pairs: %d, time passed: %02f' % (
                proc_id, i, len(count_dict), time_passed)
            limit *= 2
        entry = json.loads(line)
        if is_valid_entry():  # write this
            k = '%s %s' % (entry['author'], entry['subreddit'])  # key is author + ' ' + subreddit
            count_dict[k] = count_dict.get(k, 0) + 1
            user_post_count[entry['author']] = user_post_count.get(entry['author'], 0) + 1
    time_passed = time.time() - start_time
    print '\t%d %d posts, size of uc_dict: %d, time passed: %02f' % (
        proc_id, i, len(count_dict), time_passed)
    print '\t%d %d users' % (proc_id, len(user_post_count))
    save_pickle(uc_dict_filename, count_dict)
    save_pickle(user_count_filename, user_post_count)
def train_mixture_model(train, val, test, method='logP', recall_k=100,
                        dataset_name='d_name', overwrite=False, num_proc=None):
    """Runs the main experiment of the paper, finding the best mixing weights
    per user for these two components.

    Learns the weights per user and saves them in a file. If the file exists
    it just loads it. It evaluates on the test set. There is a memory
    component, i.e. where a person has been in the past (exploit), and a
    global component, i.e. the population preferences (explore).

    Data come in COO form, that is a numpy array of (N x 3) where each row is
    the (row, column, value) triplet of the sparse array Users x Categories.
    N is the number of entries in the array.

    :param train: train data COO matrix
    :param val: validation data COO matrix
    :param test: test data COO matrix
    :param method: Method of evaluation. Can be 'logP' or 'recall' for log
        probability per event or recall@k, respectively.
    :param recall_k: the k for recall@k. If method is 'logP' this does nothing.
    :param dataset_name: Name of the directory the results will be saved in.
    :param overwrite: Boolean, whether to overwrite learned weights or read
        them if they exist.
    :param num_proc: Number of processes to be used. If None, all the
        processors in the machine will be used.
    :return: an array of mixing weights, which is n_users x 2 (2 components:
        self and global)
    """
    filename = os.path.join(results_dir, 'mixture_model', dataset_name, 'mixing_weights.pkl')
    if os.path.exists(filename) and not overwrite:
        mix_weights = load_pickle(filename, False)
    else:
        train_matrix, global_matrix = get_train_global(train, val, test)
        components = [train_matrix, global_matrix]  # can add more components here
        mix_weights = learn_mixing_weights(components, val, num_proc=num_proc)
        save_pickle(filename, mix_weights, False)
    evaluate_method(train, val, test, mix_weights, method, recall_k)
    return mix_weights
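# Illustrative sketch only (an assumption, not the repo's `learn_mixing_weights`
# or `evaluate_method`): the two-component mixture described in the docstring
# scores category c for a user as a convex combination of the user's own
# normalized history ("exploit") and the normalized population counts
# ("explore"). The count representations here are plain 1-D numpy arrays purely
# for illustration; the repo itself works with sparse matrices.
def _example_mixture_prob(user_counts, global_counts, w_self, w_global, c):
    p_self = user_counts[c] / float(max(user_counts.sum(), 1))
    p_global = global_counts[c] / float(max(global_counts.sum(), 1))
    return w_self * p_self + w_global * p_global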
def main(config, progress):
    # save config
    with open("./log/configs.json", "a") as f:
        json.dump(config, f)
        f.write("\n")
    cprint("*" * 80)
    cprint("Experiment progress: {0:.2f}%".format(progress * 100))
    cprint("*" * 80)
    metrics = {}

    # data hyper-params
    data_path = config["data_path"]
    keyword_path = config["keyword_path"]
    pretrained_wordvec_path = config["pretrained_wordvec_path"]
    data_dir = "/".join(data_path.split("/")[:-1])
    dataset = data_path.split("/")[-2]  # convai2 or casual
    test_mode = bool(config["test_mode"])
    save_model_path = config["save_model_path"]
    load_kw_prediction_path = config["load_kw_prediction_path"]
    min_context_len = config["min_context_len"]
    max_context_len = config["max_context_len"]
    max_sent_len = config["max_sent_len"]
    max_keyword_len = config["max_keyword_len"]
    max_vocab_size = config["max_vocab_size"]
    max_keyword_vocab_size = config["max_keyword_vocab_size"]
    flatten_context = config["flatten_context"]

    # model hyper-params
    config_id = config["config_id"]
    model = config["model"]
    use_CN_hopk_graph = config["use_CN_hopk_graph"]
    use_utterance_concepts = use_CN_hopk_graph > 0
    concept_encoder = config["concept_encoder"]
    combine_word_concepts = config["combine_word_concepts"]
    gnn = config["gnn"]
    encoder = config["encoder"]
    aggregation = config["aggregation"]
    use_keywords = bool(config["use_keywords"])
    keyword_score_weight = config["keyword_score_weight"]
    keyword_encoder = config["keyword_encoder"]  # mean, max, GRU, any_max
    embed_size = config["embed_size"]
    use_pretrained_word_embedding = bool(config["use_pretrained_word_embedding"])
    fix_word_embedding = bool(config["fix_word_embedding"])
    gnn_hidden_size = config["gnn_hidden_size"]
    gnn_layers = config["gnn_layers"]
    encoder_hidden_size = config["encoder_hidden_size"]
    encoder_layers = config["encoder_layers"]
    n_heads = config["n_heads"]
    dropout = config["dropout"]

    # training hyper-params
    batch_size = config["batch_size"]
    epochs = config["epochs"]
    lr = config["lr"]
    lr_decay = config["lr_decay"]
    seed = config["seed"]
    device = torch.device(config["device"])
    fp16 = bool(config["fp16"])
    fp16_opt_level = config["fp16_opt_level"]

    # set seed
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    if "convai2" in data_dir and min_context_len != 2:
        raise ValueError("convai2 dataset has min context len of 2")
    if use_pretrained_word_embedding and str(embed_size) not in pretrained_wordvec_path:
        raise ValueError("embedding size and pretrained_wordvec_path not match")
    if use_keywords and load_kw_prediction_path == "":
        raise ValueError("kw model path needs to be provided when use_keywords is True")

    # load data
    cprint("Loading conversation data...")
    train, valid, test = load_pickle(data_path)
    train_keyword, valid_keyword, test_keyword = load_pickle(keyword_path)
    train_candidate, valid_candidate = None, None

    # load 20 candidates
    train_candidate, valid_candidate, test_candidate = load_pickle(
        os.path.join(data_dir, "candidate.pkl"))

    if test_mode:
        cprint("Testing model...")
        train = train + valid
        train_keyword = train_keyword + valid_keyword
        valid = test
        valid_keyword = test_keyword
        train_candidate = train_candidate + valid_candidate
        valid_candidate = test_candidate

    cprint("sample train: ", train[0])
    cprint("sample train keyword: ", train_keyword[0])
    cprint("sample valid: ", valid[0])
    cprint("sample valid keyword: ", valid_keyword[0])

    # clip and pad data
    train_padded_convs, train_padded_keywords = pad_and_clip_data(
        train, train_keyword, min_context_len, max_context_len + 1,
        max_sent_len, max_keyword_len)
    valid_padded_convs, valid_padded_keywords = pad_and_clip_data(
        valid, valid_keyword, min_context_len, max_context_len + 1,
        max_sent_len, max_keyword_len)
    train_padded_candidates = pad_and_clip_candidate(train_candidate, max_sent_len)
    valid_padded_candidates = pad_and_clip_candidate(valid_candidate, max_sent_len)

    # build vocab
    if "convai2" in data_dir:
        test_padded_convs, _ = pad_and_clip_data(
            test, test_keyword, min_context_len, max_context_len + 1,
            max_sent_len, max_keyword_len)
        word2id = build_vocab(train_padded_convs + valid_padded_convs + test_padded_convs,
                              max_vocab_size)  # use entire dataset for vocab
    else:
        word2id = build_vocab(train_padded_convs, max_vocab_size)
    keyword2id = build_vocab(train_padded_keywords, max_keyword_vocab_size)
    id2keyword = {idx: w for w, idx in keyword2id.items()}
    for w in keyword2id:
        if w not in word2id:
            word2id[w] = len(word2id)  # add OOV keywords to word2id
    id2word = {idx: w for w, idx in word2id.items()}
    cprint("keywords that are not in word2id: ",
           set(keyword2id.keys()) - set(word2id.keys()))
    vocab_size = len(word2id)
    keyword_vocab_size = len(keyword2id)
    cprint("vocab size: ", vocab_size)
    cprint("keyword vocab size: ", keyword_vocab_size)

    # create a mapping from keyword id to word id
    keywordid2wordid = None
    train_candidate_keyword_ids, valid_candidate_keyword_ids = None, None
    if use_keywords:
        keywordid2wordid = [
            word2id[id2keyword[i]] if id2keyword[i] in word2id else word2id["<unk>"]
            for i in range(len(keyword2id))
        ]
        keywordid2wordid = torch.LongTensor(keywordid2wordid).to(device)

        # load candidate keywords
        candidate_keyword_path = os.path.join(data_dir, "candidate_keyword.pkl")
        if os.path.exists(candidate_keyword_path):
            cprint("Loading candidate keywords from ", candidate_keyword_path)
            train_candidate_keywords, valid_candidate_keywords, test_candidate_keywords = load_pickle(
                candidate_keyword_path)
        else:
            cprint("Creating candidate keywords...")
            train_candidate_keywords = extract_keywords_from_candidates(
                train_candidate, keyword2id)
            valid_candidate_keywords = extract_keywords_from_candidates(
                valid_candidate, keyword2id)
            test_candidate_keywords = extract_keywords_from_candidates(
                test_candidate, keyword2id)
            save_pickle((train_candidate_keywords, valid_candidate_keywords,
                         test_candidate_keywords), candidate_keyword_path)

        if test_mode:
            train_candidate_keywords = train_candidate_keywords + valid_candidate_keywords
            valid_candidate_keywords = test_candidate_keywords

        # pad
        cprint("Padding candidate keywords...")
        train_padded_candidate_keywords = pad_and_clip_candidate(
            train_candidate_keywords, max_keyword_len)
        valid_padded_candidate_keywords = pad_and_clip_candidate(
            valid_candidate_keywords, max_keyword_len)

        # convert candidates to ids
        cprint("Converting candidate keywords to ids...")
        train_candidate_keyword_ids = convert_candidates_to_ids(
            train_padded_candidate_keywords, keyword2id)
        valid_candidate_keyword_ids = convert_candidates_to_ids(
            valid_padded_candidate_keywords, keyword2id)

    # load CN graph
    CN_hopk_edge_index, CN_hopk_nodeid2wordid, keywordid2nodeid, node2id, \
        CN_hopk_edge_matrix_mask = None, None, None, None, None
    if use_CN_hopk_graph > 0:
        cprint("Loading CN_hopk edge index...")
        """
        CN_graph_dict: {
            edge_index: 2D list (num_edges, 2),
            edge_weight: list (num_edges, ),
            nodeid2wordid: 2D list (num_nodes, 10),
            edge_mask: numpy array of (keyword_vocab_size, keyword_vocab_size)
        }
        """
        CN_hopk_graph_path = "./data/{0}/CN_graph_{1}hop_ge1.pkl".format(
            dataset, use_CN_hopk_graph)
        cprint("Loading graph from ", CN_hopk_graph_path)
        CN_hopk_graph_dict = load_nx_graph_hopk(CN_hopk_graph_path, word2id, keyword2id)
        CN_hopk_edge_index = torch.LongTensor(
            CN_hopk_graph_dict["edge_index"]).transpose(0, 1).to(device)  # (2, num_edges)
        CN_hopk_nodeid2wordid = torch.LongTensor(
            CN_hopk_graph_dict["nodeid2wordid"]).to(device)  # (num_nodes, 10)
        node2id = CN_hopk_graph_dict["node2id"]
        id2node = {idx: w for w, idx in node2id.items()}
        keywordid2nodeid = [
            node2id[id2keyword[i]] if id2keyword[i] in node2id else node2id["<unk>"]
            for i in range(len(keyword2id))
        ]
        keywordid2nodeid = torch.LongTensor(keywordid2nodeid).to(device)

        cprint("edge index shape: ", CN_hopk_edge_index.shape)
        cprint("edge index[:,:8]", CN_hopk_edge_index[:, :8])
        cprint("nodeid2wordid shape: ", CN_hopk_nodeid2wordid.shape)
        cprint("nodeid2wordid[:5,:8]", CN_hopk_nodeid2wordid[:5, :8])
        cprint("keywordid2nodeid shape: ", keywordid2nodeid.shape)
        cprint("keywordid2nodeid[:8]", keywordid2nodeid[:8])

    # convert tokens to ids
    train_conv_ids = convert_convs_to_ids(train_padded_convs, word2id)
    valid_conv_ids = convert_convs_to_ids(valid_padded_convs, word2id)
    train_keyword_ids = convert_convs_to_ids(train_padded_keywords, keyword2id)
    valid_keyword_ids = convert_convs_to_ids(valid_padded_keywords, keyword2id)
    train_candidate_ids, valid_candidate_ids = None, None
    train_candidate_ids = convert_candidates_to_ids(train_padded_candidates, word2id)
    valid_candidate_ids = convert_candidates_to_ids(valid_padded_candidates, word2id)

    keyword_mask_matrix = None
    if use_CN_hopk_graph > 0:
        # numpy array of (keyword_vocab_size, keyword_vocab_size)
        keyword_mask_matrix = torch.from_numpy(CN_hopk_graph_dict["edge_mask"]).float()
        cprint("building keyword mask matrix...")
        keyword_mask_matrix[torch.arange(keyword_vocab_size),
                            torch.arange(keyword_vocab_size)] = 0  # remove self loop
        cprint("keyword mask matrix non-zeros ratio: ", keyword_mask_matrix.mean())
        cprint("average number of neighbors: ", keyword_mask_matrix.sum(dim=1).mean())
        cprint("sample keyword mask matrix: ", keyword_mask_matrix[:8, :8])
        keyword_mask_matrix = keyword_mask_matrix.to(device)

    num_examples = len(train_conv_ids)
    cprint("sample train token ids: ", train_conv_ids[0])
    cprint("sample train keyword ids: ", train_keyword_ids[0])
    cprint("sample valid token ids: ", valid_conv_ids[0])
    cprint("sample valid keyword ids: ", valid_keyword_ids[0])
    cprint("sample train candidate ids: ", train_candidate_ids[0])
    cprint("sample valid candidate ids: ", valid_candidate_ids[0])
    if use_keywords:
        cprint("sample train candidate keyword ids: ", train_candidate_keyword_ids[0])
        cprint("sample valid candidate keyword ids: ", valid_candidate_keyword_ids[0])

    # create model
    if model in ["CoGraphMatcher"]:
        model_kwargs = {
            "embed_size": embed_size,
            "vocab_size": vocab_size,
            "gnn_hidden_size": gnn_hidden_size,
            "gnn_layers": gnn_layers,
            "encoder_hidden_size": encoder_hidden_size,
            "encoder_layers": encoder_layers,
            "n_heads": n_heads,
            "CN_hopk_edge_matrix_mask": CN_hopk_edge_matrix_mask,
            "nodeid2wordid": CN_hopk_nodeid2wordid,
            "keywordid2wordid": keywordid2wordid,
            "keywordid2nodeid": keywordid2nodeid,
            "concept_encoder": concept_encoder,
            "gnn": gnn,
            "encoder": encoder,
            "aggregation": aggregation,
            "use_keywords": use_keywords,
            "keyword_score_weight": keyword_score_weight,
            "keyword_encoder": keyword_encoder,
            "dropout": dropout,
            "combine_word_concepts": combine_word_concepts
        }

    # create keyword model
    kw_model = ""
    use_last_k_utterances = -1
    if use_keywords:
        kw_model = load_kw_prediction_path.split("/")[-1][:-3]  # keyword prediction model name
        if "GNN" in kw_model:
            kw_model = "KW_GNN"
            use_last_k_utterances = 2

        # load pretrained model
        cprint("Loading weights from ", load_kw_prediction_path)
        kw_model_checkpoint = torch.load(load_kw_prediction_path, map_location=device)
        if "word2id" in kw_model_checkpoint:
            keyword2id = kw_model_checkpoint.pop("word2id")
        if "model_kwargs" in kw_model_checkpoint:
            kw_model_kwargs = kw_model_checkpoint.pop("model_kwargs")
            kw_model = globals()[kw_model](**kw_model_kwargs)
        kw_model.load_state_dict(kw_model_checkpoint)
        kw_model.to(device)
        kw_model.eval()  # set to evaluation mode, no training required

    cprint("Building model...")
    model = globals()[config["model"]](**model_kwargs)

    cprint("Initializing pretrained word embeddings...")
    pretrained_word_embedding = None
    if use_pretrained_word_embedding:
        # load pretrained word embedding
        cprint("Loading pretrained word embeddings...")
        pretrained_wordvec_name = pretrained_wordvec_path.split("/")[-1][:-4]
        word_vectors_path = os.path.join(
            data_dir, "word_vectors_{0}.pkl".format(pretrained_wordvec_name))
        if os.path.exists(word_vectors_path):
            cprint("Loading pretrained word embeddings from ", word_vectors_path)
            with open(word_vectors_path, "rb") as f:
                word_vectors = pickle.load(f)
        else:
            cprint("Loading pretrained word embeddings from scratch...")
            word_vectors = load_vectors(pretrained_wordvec_path, word2id)
            cprint("Saving pretrained word embeddings to ", word_vectors_path)
            with open(word_vectors_path, "wb") as f:
                pickle.dump(word_vectors, f)

        cprint("pretrained word embedding size: ", len(word_vectors))
        pretrained_word_embedding = np.zeros((len(word2id), embed_size))
        for w, i in word2id.items():
            if w in word_vectors:
                pretrained_word_embedding[i] = np.array(word_vectors[w])
            else:
                pretrained_word_embedding[i] = np.random.randn(embed_size) / 9
        pretrained_word_embedding[0] = 0  # 0 for PAD embedding
        pretrained_word_embedding = torch.from_numpy(pretrained_word_embedding).float()
        cprint("word embedding size: ", pretrained_word_embedding.shape)
        model.init_embedding(pretrained_word_embedding, fix_word_embedding)

    cprint(model)
    cprint("number of parameters: ", count_parameters(model))
    model.to(device)

    # optimization
    amp = None
    if fp16:
        from apex import amp
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = LambdaLR(
        optimizer,
        lr_lambda=lambda step: 1 / (1 + lr_decay * step / (num_examples / batch_size)))
    if fp16:
        model, optimizer = amp.initialize(model, optimizer, opt_level=fp16_opt_level)

    # training
    epoch_train_losses = []
    epoch_valid_losses = []
    epoch_valid_precisions = []
    epoch_valid_recalls = []
    epoch_valid_MRRs = []
    best_model_statedict = {}
    cprint("Start training...")
    for epoch in range(epochs):
        cprint("-" * 80)
        cprint("Epoch", epoch + 1)
        train_batches = create_batches_retrieval(
            train_conv_ids, train_keyword_ids, train_candidate_ids,
            train_candidate_keyword_ids, 2 * max_keyword_len, batch_size,
            shuffle=True, use_keywords=use_keywords, use_candidate_keywords=use_keywords,
            use_utterance_concepts=use_utterance_concepts, node2id=node2id,
            id2word=id2word, flatten_context=flatten_context,
            use_last_k_utterances=use_last_k_utterances)
        valid_batches = create_batches_retrieval(
            valid_conv_ids, valid_keyword_ids, valid_candidate_ids,
            valid_candidate_keyword_ids, 2 * max_keyword_len, batch_size,
            shuffle=False, use_keywords=use_keywords, use_candidate_keywords=use_keywords,
            use_utterance_concepts=use_utterance_concepts, node2id=node2id,
            id2word=id2word, flatten_context=flatten_context,
            use_last_k_utterances=use_last_k_utterances)

        if epoch == 0:
            cprint("number of optimization steps per epoch: ", len(train_batches))  # 3361
            cprint("train batches 1st example: ")
            for k, v in train_batches[0].items():
                if k == "batch_context":
                    utters = []
                    for utter in v[0]:
                        utters.append([id2word[w] for w in utter])
                    cprint("\n", k, v[0], utters)
                if k == "batch_candidates":
                    utters = []
                    for utter in v[0]:
                        utters.append([id2word[w] for w in utter])
                    cprint("\n", k, v[0], utters)
                if k == "batch_context_kw":
                    cprint("\n", k, v[0], [id2keyword[w] for w in v[0]])
                if k == "batch_candidates_kw":
                    utters = []
                    for utter in v[0]:
                        utters.append([id2keyword[w] for w in utter])
                    cprint("\n", k, v[0], utters)
                if k == "batch_context_concepts":
                    if len(v[0][0]) > 0:
                        utters = []
                        for utter in v[0]:
                            utters.append([id2node[w] for w in utter])
                        cprint("\n", k, v[0], utters)
                if k == "batch_candidates_concepts":
                    utters = []
                    for utter in v[0]:
                        utters.append([id2node[w] for w in utter])
                    cprint("\n", k, v[0], utters)
                if k == "batch_context_for_keyword_prediction":
                    utters = []
                    for utter in v[0]:
                        utters.append([id2word[w] for w in utter])
                    cprint("\n", k, v[0], utters)
                if k == "batch_context_concepts_for_keyword_prediction":
                    cprint("\n", k, v[0], [id2node[w] for w in v[0]])

        model.train()
        train_loss, (_, _, _) = run_epoch(
            train_batches, model, optimizer, training=True, device=device,
            fp16=fp16, amp=amp, kw_model=kw_model,
            keyword_mask_matrix=keyword_mask_matrix, step_scheduler=scheduler,
            keywordid2wordid=keywordid2wordid, CN_hopk_edge_index=CN_hopk_edge_index)

        model.eval()
        valid_loss, (valid_precision, valid_recall, valid_MRR) = run_epoch(
            valid_batches, model, optimizer, training=False, device=device,
            kw_model=kw_model, keyword_mask_matrix=keyword_mask_matrix,
            keywordid2wordid=keywordid2wordid, CN_hopk_edge_index=CN_hopk_edge_index)

        # scheduler.step()
        cprint(
            "Config id: {0}, Epoch {1}: train loss: {2:.4f}, valid loss: {3:.4f}, "
            "valid precision: {4}, valid recall: {5}, valid MRR: {6}".format(
                config_id, epoch + 1, train_loss, valid_loss, valid_precision,
                valid_recall, valid_MRR))
        if scheduler is not None:
            cprint("Current learning rate: ", scheduler.get_last_lr())

        epoch_train_losses.append(train_loss)
        epoch_valid_losses.append(valid_loss)
        epoch_valid_precisions.append(valid_precision)
        epoch_valid_recalls.append(valid_recall)
        epoch_valid_MRRs.append(valid_MRR)

        if save_model_path != "":
            if epoch == 0:
                for k, v in model.state_dict().items():
                    best_model_statedict[k] = v.cpu()
            else:
                if epoch_valid_recalls[-1][0] == max(
                        [recall1 for recall1, _, _ in epoch_valid_recalls]):
                    for k, v in model.state_dict().items():
                        best_model_statedict[k] = v.cpu()

        # early stopping
        if len(epoch_valid_recalls) >= 3 and \
                epoch_valid_recalls[-1][0] < epoch_valid_recalls[-2][0] and \
                epoch_valid_recalls[-2][0] < epoch_valid_recalls[-3][0]:
            break

    config.pop("seed")
    config.pop("config_id")
    metrics["config"] = config
    metrics["score"] = max([recall[0] for recall in epoch_valid_recalls])
    metrics["epoch"] = np.argmax([recall[0] for recall in epoch_valid_recalls]).item()
    metrics["recall"] = epoch_valid_recalls[metrics["epoch"]]
    metrics["MRR"] = epoch_valid_MRRs[metrics["epoch"]]
    metrics["precision"] = epoch_valid_precisions[metrics["epoch"]]

    if save_model_path and seed == 1:
        cprint("Saving model to ", save_model_path)
        best_model_statedict["word2id"] = word2id
        best_model_statedict["model_kwargs"] = model_kwargs
        torch.save(best_model_statedict, save_model_path)

    return metrics
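# Quick sanity sketch (illustration only, not called by main() above): the
# LambdaLR schedule defined in main() multiplies the base learning rate by
# 1 / (1 + lr_decay * step / steps_per_epoch), so with lr_decay = 1.0 the rate
# is halved after roughly one epoch worth of optimizer steps.
def _example_lr_multiplier(step, lr_decay, num_examples, batch_size):
    steps_per_epoch = num_examples / batch_size
    return 1 / (1 + lr_decay * step / steps_per_epoch)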
# load candidate keywords
candidate_keyword_path = "./data/{0}/candidate_pool_keyword.pkl".format(dataset)
candidate_keyword_ids = []
if os.path.exists(candidate_keyword_path):
    print("Loading candidate keywords from ", candidate_keyword_path)
    candidate_keyword_ids = load_pickle(candidate_keyword_path)
if len(candidate_keyword_ids) != len(candidate_pool):
    print("Creating candidate keywords...")
    candidate_keyword_ids = []
    pad_token = "<pad>"
    for cand in tqdm(candidate_pool):
        cand_kw_tokens = pad_sentence(kw_tokenize(cand), 10, pad_token)
        cand_kw_tokens = [keyword2id[w] if w in keyword2id else keyword2id["<unk>"]
                          for w in cand_kw_tokens]
        candidate_keyword_ids.append(cand_kw_tokens)
    save_pickle(candidate_keyword_ids, candidate_keyword_path)

print("sample candidate_pool_ids: ", len(candidate_pool_ids), candidate_pool_ids[:3],
      [[id2word[t] for t in u if t != 0] for u in candidate_pool_ids[:3]])
print("sample candidate_concept_ids: ", len(candidate_concept_ids), candidate_concept_ids[:3],
      [[id2node[t] for t in u if t != 0] for u in candidate_concept_ids[:3]])
print("sample candidate_keyword_ids: ", len(candidate_keyword_ids), candidate_keyword_ids[:3],
      [[id2keyword[t] for t in u if t != 0] for u in candidate_keyword_ids[:3]])

# encode candidates for CoGraphMatcher
print("Encoding candidate pool for chat model...")
with torch.no_grad():
    chunk_size = 2000
    chunk_ids = list(range(0, len(candidate_pool_ids), chunk_size)) + [len(candidate_pool_ids)]
    chunk_candidate_outs = []
    for s, e in zip(chunk_ids[:-1], chunk_ids[1:]):
        chunk_candidate_outs.append(chat_model.encode_candidate_offline(
            torch.LongTensor(candidate_pool_ids[s:e]).to(device).unsqueeze(0),
            torch.LongTensor(candidate_concept_ids[s:e]).to(device).unsqueeze(0),