Code example #1
    def sentence_retrieval(self):
        print('- sentence retrieval: initialise')
        with open(self.path_word_dict_stage_2, "rb") as f:
            word_dict = pickle.load(f)

        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # map_location lets a GPU-trained checkpoint load on CPU-only machines
        checkpoint = torch.load(self.path_stage_2_model, map_location=device)

        # Recover the model dimensions from the stored weight shapes
        vocab_size = checkpoint["model"]["_word_embedding.weight"].size(0)
        embedding_dim = checkpoint["model"]["_word_embedding.weight"].size(1)
        hidden_size = checkpoint["model"]["_projection.0.weight"].size(0)
        num_classes = checkpoint["model"]["_classification.4.weight"].size(0)

        # Optional embedding features: out-of-vocabulary and POS-tag flags
        use_oov_flag = int(
            'oov' in self.embeddings_settings_sentence_retrieval_list)
        use_pos_tag_flag = int(
            'pos' in self.embeddings_settings_sentence_retrieval_list)

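        # Rebuild the ESIM model with the dimensions recovered from the checkpoint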
        model = ESIM(vocab_size,
                     embedding_dim,
                     hidden_size,
                     num_classes=num_classes,
                     use_pos_tag_flag=use_pos_tag_flag,
                     use_oov_flag=use_oov_flag,
                     device=device).to(device)

        model.load_state_dict(checkpoint["model"])

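        # Inference only: switch to evaluation mode (disables dropout and other
        # training-only behaviour)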
        model.eval()

        print('- sentence retrieval: iterate through claims')
        for claim_nr in tqdm(range(self.nr_claims)):
            path_claim = os.path.join(self.path_document_retrieval_dir,
                                      str(claim_nr) + '.json')
            claim_dict = dict_load_json(path_claim)

            list_prob = []
            list_doc_nr = []
            list_line_nr = []

            # Score every retrieved sentence for this claim
            claim_dict.setdefault('sentence_retrieval', {})
            for doc_nr in claim_dict['document_retrieval']:
                claim_dict['sentence_retrieval'].setdefault(doc_nr, {})
                for line_nr in claim_dict['document_retrieval'][doc_nr]:
                    claim_dict['sentence_retrieval'][doc_nr].setdefault(
                        line_nr, {})

                    prob = compute_prob_stage_2(model, claim_dict, doc_nr,
                                                line_nr, device)
                    claim_dict['sentence_retrieval'][doc_nr][line_nr][
                        'prob'] = prob

                    list_doc_nr.append(doc_nr)
                    list_line_nr.append(line_nr)
                    list_prob.append(prob)

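            # Keep the five sentences with the highest probability of being evidence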
            sorted_list_doc_nr = sort_list(list_doc_nr, list_prob)[-5:]
            sorted_list_line_nr = sort_list(list_line_nr, list_prob)[-5:]
            sorted_list_prob = sort_list(list_prob, list_prob)[-5:]
            claim_dict['sentence_retrieval'][
                'doc_nr_list'] = sorted_list_doc_nr
            claim_dict['sentence_retrieval'][
                'line_nr_list'] = sorted_list_line_nr
            claim_dict['sentence_retrieval']['prob_list'] = sorted_list_prob

            # wiki_database (assumed to be available in scope) maps document
            # ids back to Wikipedia page titles
            claim_dict['predicted_evidence'] = []
            for doc_nr, line_nr in zip(sorted_list_doc_nr,
                                       sorted_list_line_nr):
                title = wiki_database.get_title_from_id(int(doc_nr))
                claim_dict['predicted_evidence'].append([title, int(line_nr)])

            path_save = os.path.join(self.path_sentence_retrieval_dir,
                                     str(claim_nr) + '.json')
            self.save_dict(claim_dict, path_save)
Code example #2
    def label_prediction(self):
        print('- label prediction: initialise')
        with open(self.path_word_dict_stage_3, "rb") as f:
            word_dict = pickle.load(f)

        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # map_location lets a GPU-trained checkpoint load on CPU-only machines
        checkpoint = torch.load(self.path_stage_3_model, map_location=device)

        # Recover the model dimensions from the stored weight shapes
        vocab_size = checkpoint["model"]["_word_embedding.weight"].size(0)
        embedding_dim = checkpoint["model"]["_word_embedding.weight"].size(1)
        hidden_size = checkpoint["model"]["_projection.0.weight"].size(0)
        num_classes = checkpoint["model"]["_classification.4.weight"].size(0)

        # Optional embedding features: out-of-vocabulary and POS-tag flags
        use_oov_flag = int(
            'oov' in self.embeddings_settings_label_prediction_list)
        use_pos_tag_flag = int(
            'pos' in self.embeddings_settings_label_prediction_list)

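        # Rebuild the stage-3 (label prediction) ESIM with the checkpoint's dimensions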
        model = ESIM(vocab_size,
                     embedding_dim,
                     hidden_size,
                     num_classes=num_classes,
                     use_pos_tag_flag=use_pos_tag_flag,
                     use_oov_flag=use_oov_flag,
                     device=device).to(device)

        model.load_state_dict(checkpoint["model"])

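        # Evaluation mode for inference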
        model.eval()

        print('- label prediction: iterate through claims')
        for claim_nr in tqdm(range(self.nr_claims)):
            path_claim = os.path.join(self.path_sentence_retrieval_dir,
                                      str(claim_nr) + '.json')
            claim_dict = dict_load_json(path_claim)

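            # Classify the claim against each of the top-5 retrieved sentences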
            prob_list = []
            prob_list_supported = []
            prob_list_refuted = []
            for i in range(len(
                    claim_dict['sentence_retrieval']['doc_nr_list'])):
                doc_nr = claim_dict['sentence_retrieval']['doc_nr_list'][i]
                line_nr = claim_dict['sentence_retrieval']['line_nr_list'][i]
                if doc_nr in claim_dict['document_retrieval']:
                    if line_nr in claim_dict['document_retrieval'][doc_nr]:
                        prob = compute_prob_stage_3(model, claim_dict, doc_nr,
                                                    line_nr, device)
                        prob_list.append(prob)
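                        # class index 2 is treated as SUPPORTS, index 1 as REFUTES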
                        prob_list_supported.append(prob[2])
                        prob_list_refuted.append(prob[1])
                    else:
                        print('line_nr not in list', line_nr)
                else:
                    print('doc_nr not in list', doc_nr)
            # Guard against empty lists (no valid evidence was scored) before
            # calling max()
            if prob_list_supported and max(prob_list_supported) > 0.5:
                claim_dict['predicted_label'] = 'SUPPORTS'
            elif prob_list_refuted and max(prob_list_refuted) > 0.5:
                claim_dict['predicted_label'] = 'REFUTES'
            else:
                claim_dict['predicted_label'] = 'NOT ENOUGH INFO'

            path_save = os.path.join(self.path_label_prediction_dir,
                                     str(claim_nr) + '.json')
            self.save_dict(claim_dict, path_save)