def sentence_retrieval(self):
    print('- sentence retrieval: initialise')

    # Vocabulary used when training the stage-2 (sentence retrieval) model.
    # It is not referenced directly below; it is presumably consumed by the
    # scoring helpers.
    with open(self.path_word_dict_stage_2, "rb") as f:
        word_dict = pickle.load(f)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # map_location lets a GPU-trained checkpoint load on a CPU-only machine.
    checkpoint = torch.load(self.path_stage_2_model, map_location=device)

    # Recover the model dimensions from the checkpoint's parameter shapes.
    vocab_size = checkpoint["model"]["_word_embedding.weight"].size(0)
    embedding_dim = checkpoint["model"]["_word_embedding.weight"].size(1)
    hidden_size = checkpoint["model"]["_projection.0.weight"].size(0)
    num_classes = checkpoint["model"]["_classification.4.weight"].size(0)

    use_oov_flag = 1 if 'oov' in self.embeddings_settings_sentence_retrieval_list else 0
    use_pos_tag_flag = 1 if 'pos' in self.embeddings_settings_sentence_retrieval_list else 0

    model = ESIM(vocab_size,
                 embedding_dim,
                 hidden_size,
                 num_classes=num_classes,
                 use_pos_tag_flag=use_pos_tag_flag,
                 use_oov_flag=use_oov_flag,
                 device=device).to(device)
    model.load_state_dict(checkpoint["model"])
    model.eval()

    print('- sentence retrieval: iterate through claims')
    for claim_nr in tqdm(range(self.nr_claims)):
        path_claim = os.path.join(self.path_document_retrieval_dir,
                                  str(claim_nr) + '.json')
        claim_dict = dict_load_json(path_claim)

        list_prob = []
        list_doc_nr = []
        list_line_nr = []

        # Score every (document, line) candidate produced by document retrieval.
        for doc_nr in claim_dict['document_retrieval']:
            for line_nr in claim_dict['document_retrieval'][doc_nr]:
                if 'sentence_retrieval' not in claim_dict:
                    claim_dict['sentence_retrieval'] = {}
                if doc_nr not in claim_dict['sentence_retrieval']:
                    claim_dict['sentence_retrieval'][doc_nr] = {}
                if line_nr not in claim_dict['sentence_retrieval'][doc_nr]:
                    claim_dict['sentence_retrieval'][doc_nr][line_nr] = {}

                prob = compute_prob_stage_2(model, claim_dict, doc_nr,
                                            line_nr, device)
                claim_dict['sentence_retrieval'][doc_nr][line_nr]['prob'] = prob
                list_doc_nr.append(doc_nr)
                list_line_nr.append(line_nr)
                list_prob.append(prob)

        # Keep the five highest-probability sentences: sort_list orders its
        # first argument by ascending probability, so the last five entries
        # are the top 5 (see the sketch after this method).
        sorted_list_doc_nr = sort_list(list_doc_nr, list_prob)[-5:]
        sorted_list_line_nr = sort_list(list_line_nr, list_prob)[-5:]
        sorted_list_prob = sort_list(list_prob, list_prob)[-5:]

        claim_dict['sentence_retrieval']['doc_nr_list'] = sorted_list_doc_nr
        claim_dict['sentence_retrieval']['line_nr_list'] = sorted_list_line_nr
        claim_dict['sentence_retrieval']['prob_list'] = sorted_list_prob

        # Record the evidence in FEVER format: [page title, line number].
        # `wiki_database` is assumed to be available in the enclosing scope.
        claim_dict['predicted_evidence'] = []
        for i in range(len(sorted_list_doc_nr)):
            doc_nr = sorted_list_doc_nr[i]
            title = wiki_database.get_title_from_id(int(doc_nr))
            line_nr = int(sorted_list_line_nr[i])
            claim_dict['predicted_evidence'].append([title, line_nr])

        path_save = os.path.join(self.path_sentence_retrieval_dir,
                                 str(claim_nr) + '.json')
        self.save_dict(claim_dict, path_save)
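# A minimal sketch of the `sort_list` helper assumed above: it returns its
# first argument reordered by the corresponding entries of `keys` in
# ascending order, so the [-5:] slice selects the five highest-probability
# candidates. This illustrates the assumed behaviour only; the repository's
# own definition may differ.
def sort_list(values, keys):
    """Return `values` sorted by the corresponding entries of `keys`, ascending."""
    return [value for _, value in
            sorted(zip(keys, values), key=lambda pair: pair[0])]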
def label_prediction(self):
    print('- label prediction: initialise')

    # Vocabulary used when training the stage-3 (label prediction) model.
    with open(self.path_word_dict_stage_3, "rb") as f:
        word_dict = pickle.load(f)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # map_location lets a GPU-trained checkpoint load on a CPU-only machine.
    checkpoint = torch.load(self.path_stage_3_model, map_location=device)

    # Recover the model dimensions from the checkpoint's parameter shapes.
    vocab_size = checkpoint["model"]["_word_embedding.weight"].size(0)
    embedding_dim = checkpoint["model"]["_word_embedding.weight"].size(1)
    hidden_size = checkpoint["model"]["_projection.0.weight"].size(0)
    num_classes = checkpoint["model"]["_classification.4.weight"].size(0)

    use_oov_flag = 1 if 'oov' in self.embeddings_settings_label_prediction_list else 0
    use_pos_tag_flag = 1 if 'pos' in self.embeddings_settings_label_prediction_list else 0

    model = ESIM(vocab_size,
                 embedding_dim,
                 hidden_size,
                 num_classes=num_classes,
                 use_pos_tag_flag=use_pos_tag_flag,
                 use_oov_flag=use_oov_flag,
                 device=device).to(device)
    model.load_state_dict(checkpoint["model"])
    model.eval()

    print('- label prediction: iterate through claims')
    for claim_nr in tqdm(range(self.nr_claims)):
        path_claim = os.path.join(self.path_sentence_retrieval_dir,
                                  str(claim_nr) + '.json')
        claim_dict = dict_load_json(path_claim)

        prob_list = []
        prob_list_supported = []
        prob_list_refuted = []

        # Score each retrieved evidence sentence with the stage-3 model.
        # Index convention in the class probabilities: prob[1] = REFUTES,
        # prob[2] = SUPPORTS.
        for i in range(len(claim_dict['sentence_retrieval']['doc_nr_list'])):
            doc_nr = claim_dict['sentence_retrieval']['doc_nr_list'][i]
            line_nr = claim_dict['sentence_retrieval']['line_nr_list'][i]
            if doc_nr in claim_dict['document_retrieval']:
                if line_nr in claim_dict['document_retrieval'][doc_nr]:
                    prob = compute_prob_stage_3(model, claim_dict, doc_nr,
                                                line_nr, device)
                    prob_list.append(prob)
                    prob_list_supported.append(prob[2])
                    prob_list_refuted.append(prob[1])
                else:
                    print('line_nr not in list', line_nr)
            else:
                print('doc_nr not in list', doc_nr)

        # Aggregate the per-sentence probabilities: SUPPORTS if any sentence
        # supports the claim with probability > 0.5, REFUTES if any sentence
        # refutes it with probability > 0.5, NOT ENOUGH INFO otherwise. The
        # truthiness guards avoid calling max() on an empty list when no
        # candidate sentence could be scored.
        if prob_list_supported and max(prob_list_supported) > 0.5:
            claim_dict['predicted_label'] = 'SUPPORTS'
        elif prob_list_refuted and max(prob_list_refuted) > 0.5:
            claim_dict['predicted_label'] = 'REFUTES'
        else:
            claim_dict['predicted_label'] = 'NOT ENOUGH INFO'

        path_save = os.path.join(self.path_label_prediction_dir,
                                 str(claim_nr) + '.json')
        self.save_dict(claim_dict, path_save)
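# A minimal usage sketch, assuming the two methods above belong to a pipeline
# class (called `Pipeline` here purely for illustration; the actual class
# name and constructor arguments in this repository may differ). Sentence
# retrieval must run first, because label prediction reads the per-claim
# JSON files that sentence retrieval writes.
if __name__ == '__main__':
    pipeline = Pipeline()          # hypothetical constructor
    pipeline.sentence_retrieval()  # writes <claim_nr>.json with top-5 evidence
    pipeline.label_prediction()    # writes <claim_nr>.json with predicted labels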