Example #1
    def compute_score(self):
        # stage 2 metrics: precision, recall and F1 of the retrieved evidence
        # stage 3 metric: the FEVER score (label accuracy with correct evidence)

        # collect the per-claim prediction dicts written by label_prediction()
        list_claims = []
        for claim_nr in tqdm(range(self.nr_claims)):
            path_claim = os.path.join(self.path_label_prediction_dir,
                                      str(claim_nr) + '.json')
            claim_dict = dict_load_json(path_claim)
            list_claims.append(claim_dict)

        strict_score, acc_score, pr, rec, f1 = fever_score(
            predictions=list_claims, actual=None, max_evidence=5)

        print(strict_score, acc_score, pr, rec, f1)
        self.settings['score_metrics'] = {
            'strict_score': strict_score,
            'acc_score': acc_score,
            'pr': pr,
            'rec': rec,
            'f1': f1,
        }

        self.save_settings()
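
fever_score is the metric function from the fever-scorer package; because actual=None here, each prediction dict must also carry the gold label and gold evidence. A minimal sketch of the shape each element of list_claims presumably has (all values are illustrative, not taken from the original data):

# shape each element of list_claims presumably has when actual=None
# (values below are illustrative only)
prediction = {
    'label': 'SUPPORTS',                                  # gold label
    'predicted_label': 'SUPPORTS',                        # stage 3 output
    'evidence': [[[None, None, 'Colin_Kaepernick', 5]]],  # gold evidence sets
    'predicted_evidence': [['Colin_Kaepernick', 5]],      # stage 2 output
}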
Example #2
    def __init__(self, path_settings_dir, file_name='settings'):
        self.path_settings = os.path.join(path_settings_dir, file_name + '.json')

        # load existing settings if present, otherwise start empty and persist
        if os.path.isfile(self.path_settings):
            self.settings_dict = dict_load_json(self.path_settings)
        else:
            self.settings_dict = {}
            self.save_settings()
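
All of these excerpts lean on small JSON helpers (dict_load_json, the save_settings wrappers) that are never defined. A minimal sketch of what dict_load_json and a matching save helper presumably look like; the name dict_save_json is an assumption:

import json

def dict_load_json(path, encoding='utf-8'):
    # read a JSON file into a dict
    with open(path, 'r', encoding=encoding) as f:
        return json.load(f)

def dict_save_json(dictionary, path, encoding='utf-8'):
    # write a dict to disk as JSON (assumed counterpart of dict_load_json;
    # the excerpts only ever call it through save_settings-style wrappers)
    with open(path, 'w', encoding=encoding) as f:
        json.dump(dictionary, f)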
Example #3
    def __init__(self,
                 path_database_dir,
                 database_name,
                 database_method,
                 input_type,
                 output_type,
                 encoding='utf-8',
                 checks_flag=True):
        # input:
        # - path_database_dir : path of the directory of the database
        # - database_name : name of the database, without extension
        # - database_method : package/method used to construct the database
        #     ('lsm' or 'json')
        # - input_type : type of the keys, fixed for the lifetime of the
        #     database (e.g. 'string', 'int', 'float')
        # - output_type : type of the values, fixed for the lifetime of the
        #     database (e.g. 'string', 'int', 'float')
        # - encoding : character encoding
        # - checks_flag : enables additional consistency checks

        self.path_database_dir = path_database_dir
        self.database_name = database_name
        self.database_method = database_method
        self.input_type = input_type
        self.output_type = output_type
        self.encoding = encoding
        self.checks_flag = checks_flag

        self.path_settings = os.path.join(
            path_database_dir,
            'settings_' + database_name + '_' + self.database_method + '.json')

        mkdir_if_not_exist(path_database_dir)

        self.settings_keys = [
            'database_method', 'input_type', 'output_type', 'encoding'
        ]
        self.settings_values = [
            self.database_method, self.input_type, self.output_type,
            self.encoding
        ]

        if os.path.isfile(self.path_settings):
            # verify that the stored settings match the ones passed in, and
            # keep the loaded dictionary on the instance
            self.settings = dict_load_json(self.path_settings)
            if (self.settings.get('settings_keys') != self.settings_keys
                    or self.settings.get('settings_values') != self.settings_values):
                raise ValueError(
                    'saved settings dictionary does not correspond to the '
                    'settings passed for this database')
        else:
            self.settings = dict(zip(self.settings_keys, self.settings_values))
            self.settings['settings_keys'] = self.settings_keys
            self.settings['settings_values'] = self.settings_values
            self.save_settings()

        if self.database_method == 'lsm':
            # LSM key/value store (lsm-db package)
            self.path_database = os.path.join(path_database_dir,
                                              database_name + '.ldb')
            self.db = LSM(self.path_database)

        elif self.database_method == 'json':
            # only allows data types that can be converted to string
            self.path_database = os.path.join(path_database_dir,
                                              database_name + '.json')
            list_allowed_types = ['string', 'int', 'float']
            if self.input_type not in list_allowed_types:
                raise ValueError('input type not in allowed list',
                                 self.input_type)

            if self.output_type not in list_allowed_types:
                raise ValueError('output type not in allowed list',
                                 self.output_type)

            if os.path.isfile(self.path_database):
                print('load database at: ' + self.path_database)
                self.db = database_load_json(self.path_database, self.encoding)
            else:
                self.db = {}
        else:
            raise ValueError('database_method is not in options',
                             self.database_method)
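
The excerpt only shows __init__, so the class name is not visible; assuming it is called Database, a hypothetical instantiation would look like this (all paths and names are made up for illustration):

# hypothetical usage; the name Database and every path are assumed
db = Database(path_database_dir='data/wiki_db',
              database_name='title_2_id',
              database_method='json',
              input_type='string',
              output_type='int')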
Example #4
    def load_dict(self, path):
        # thin wrapper around the JSON helper
        return dict_load_json(path)
Example #5
    def document_retrieval(self):
        # tokenise every claim and every line of its selected documents, and
        # store the id-mapped tag/word/exact-match features per (doc, line)
        with open(self.path_word_dict_stage_3, "rb") as f:
            word_dict = pickle.load(f)

        for claim_nr in tqdm(range(self.nr_claims)):
            path_claim = os.path.join(self.path_dir_doc_selected,
                                      str(claim_nr) + '.json')
            claim_dict = dict_load_json(path_claim)
            claim = Claim(claim_dict)
            claim_text = claim.claim
            # === process word tags and word list === #
            tag_list_claim, word_list_claim = get_word_tag_list_from_text(
                text_str=claim_text,
                nlp=self.nlp,
                method_tokenization_str=self.method_tokenization)

            for doc_nr in claim_dict['docs_selected']:
                line_list = self.wiki_database.get_lines_from_id(doc_nr)
                nr_lines = len(line_list)
                for line_nr in range(nr_lines):
                    line_text = line_list[line_nr]

                    # === process word tags and word list === #
                    tag_list_line, word_list_line = get_word_tag_list_from_text(
                        text_str=line_text,
                        nlp=self.nlp,
                        method_tokenization_str=self.method_tokenization)

                    # create the nested entry for this (document, line) pair
                    doc_key, line_key = str(doc_nr), str(line_nr)
                    retrieval_dict = claim_dict.setdefault('document_retrieval', {})
                    entry = retrieval_dict.setdefault(doc_key, {}).setdefault(line_key, {})
                    claim_entry = entry.setdefault('claim', {})
                    document_entry = entry.setdefault('document', {})

                    # tag/word id sequences; 17 is presumably the tag id for
                    # the _BOS_/_EOS_ sentinel tokens
                    claim_entry['tag_list'] = [17] + tag_str_2_id_list(
                        tag_list_claim, self.tag_2_id_dict) + [17]
                    claim_entry['word_list'] = word_list_2_id_list(
                        ["_BOS_"] + word_list_claim + ["_EOS_"], word_dict)
                    document_entry['tag_list'] = [17] + tag_str_2_id_list(
                        tag_list_line, self.tag_2_id_dict) + [17]
                    document_entry['word_list'] = word_list_2_id_list(
                        ["_BOS_"] + word_list_line + ["_EOS_"], word_dict)

                    # exact-match features, padded at the sentinel positions
                    ids_document, ids_claim = hypothesis_evidence_2_index(
                        hypothesis=word_list_line,
                        premise=word_list_claim,
                        randomise_flag=False)
                    claim_entry['exact_match_list'] = [0] + ids_claim + [1]
                    document_entry['exact_match_list'] = [0] + ids_document + [1]

            path_save = os.path.join(self.path_document_retrieval_dir,
                                     str(claim_nr) + '.json')
            self.save_dict(claim_dict, path_save)
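
hypothesis_evidence_2_index is not shown in these excerpts. From the way its output is stored as exact_match_list, it plausibly marks which tokens of each sentence also occur in the other; a guessed sketch under that assumption, not the original implementation:

def hypothesis_evidence_2_index(hypothesis, premise, randomise_flag=False):
    # guessed behaviour: 1 where a token also appears in the other sentence,
    # 0 elsewhere; randomise_flag is ignored in this sketch
    premise_set, hypothesis_set = set(premise), set(hypothesis)
    ids_hypothesis = [1 if word in premise_set else 0 for word in hypothesis]
    ids_premise = [1 if word in hypothesis_set else 0 for word in premise]
    return ids_hypothesis, ids_premise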
Example #6
    def sentence_retrieval(self):
        print('- sentence retrieval: initialise')
        # loaded for parity with the other stages; not used below
        with open(self.path_word_dict_stage_2, "rb") as f:
            word_dict = pickle.load(f)

        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        checkpoint = torch.load(self.path_stage_2_model)

        # recover the model hyper-parameters from the checkpoint weight shapes
        vocab_size = checkpoint["model"]["_word_embedding.weight"].size(0)
        embedding_dim = checkpoint["model"]["_word_embedding.weight"].size(1)
        hidden_size = checkpoint["model"]["_projection.0.weight"].size(0)
        num_classes = checkpoint["model"]["_classification.4.weight"].size(0)

        use_oov_flag = int('oov' in self.embeddings_settings_sentence_retrieval_list)
        use_pos_tag_flag = int('pos' in self.embeddings_settings_sentence_retrieval_list)

        model = ESIM(vocab_size,
                     embedding_dim,
                     hidden_size,
                     num_classes=num_classes,
                     use_pos_tag_flag=use_pos_tag_flag,
                     use_oov_flag=use_oov_flag,
                     device=device).to(device)

        model.load_state_dict(checkpoint["model"])

        model.eval()

        print('- sentence retrieval: iterate through claims')
        for claim_nr in tqdm(range(self.nr_claims)):
            path_claim = os.path.join(self.path_document_retrieval_dir,
                                      str(claim_nr) + '.json')
            claim_dict = dict_load_json(path_claim)

            list_prob = []
            list_doc_nr = []
            list_line_nr = []

            for doc_nr in claim_dict['document_retrieval']:
                for line_nr in claim_dict['document_retrieval'][doc_nr]:
                    entry = claim_dict.setdefault(
                        'sentence_retrieval', {}).setdefault(
                            doc_nr, {}).setdefault(line_nr, {})

                    # probability that this line is evidence for the claim
                    prob = compute_prob_stage_2(model, claim_dict, doc_nr,
                                                line_nr, device)
                    entry['prob'] = prob

                    list_doc_nr.append(doc_nr)
                    list_line_nr.append(line_nr)
                    list_prob.append(prob)

            # keep the five highest-probability lines (FEVER scores at most
            # five evidence sentences); sort_list presumably reorders its
            # first argument by the values of its second, ascending
            sorted_list_doc_nr = sort_list(list_doc_nr, list_prob)[-5:]
            sorted_list_line_nr = sort_list(list_line_nr, list_prob)[-5:]
            sorted_list_prob = sort_list(list_prob, list_prob)[-5:]
            claim_dict['sentence_retrieval']['doc_nr_list'] = sorted_list_doc_nr
            claim_dict['sentence_retrieval']['line_nr_list'] = sorted_list_line_nr
            claim_dict['sentence_retrieval']['prob_list'] = sorted_list_prob

            # convert (doc_nr, line_nr) pairs into the [title, line] format
            # the FEVER scorer expects
            claim_dict['predicted_evidence'] = []
            for i in range(len(sorted_list_doc_nr)):
                doc_nr = sorted_list_doc_nr[i]
                title = self.wiki_database.get_title_from_id(int(doc_nr))
                line_nr = int(sorted_list_line_nr[i])
                claim_dict['predicted_evidence'].append([title, line_nr])

            path_save = os.path.join(self.path_sentence_retrieval_dir,
                                     str(claim_nr) + '.json')
            self.save_dict(claim_dict, path_save)
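
sort_list is another undefined helper; for the [-5:] slices above to select the top five, it must return its first argument reordered by the values of its second, ascending. A minimal sketch under that assumption:

def sort_list(list_in, list_sort):
    # reorder list_in by the corresponding values in list_sort (ascending),
    # so slicing the result with [-5:] keeps the five highest-scored items
    return [x for _, x in sorted(zip(list_sort, list_in),
                                 key=lambda pair: pair[0])]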
Example #7
    def label_prediction(self):
        print('- label prediction: initialise')
        # loaded for parity with the other stages; not used below
        with open(self.path_word_dict_stage_3, "rb") as f:
            word_dict = pickle.load(f)

        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        checkpoint = torch.load(self.path_stage_3_model)

        # recover the model hyper-parameters from the checkpoint weight shapes
        vocab_size = checkpoint["model"]["_word_embedding.weight"].size(0)
        embedding_dim = checkpoint["model"]["_word_embedding.weight"].size(1)
        hidden_size = checkpoint["model"]["_projection.0.weight"].size(0)
        num_classes = checkpoint["model"]["_classification.4.weight"].size(0)

        use_oov_flag = int('oov' in self.embeddings_settings_label_prediction_list)
        use_pos_tag_flag = int('pos' in self.embeddings_settings_label_prediction_list)

        model = ESIM(vocab_size,
                     embedding_dim,
                     hidden_size,
                     num_classes=num_classes,
                     use_pos_tag_flag=use_pos_tag_flag,
                     use_oov_flag=use_oov_flag,
                     device=device).to(device)

        model.load_state_dict(checkpoint["model"])

        model.eval()

        print('- label prediction: iterate through claims')
        for claim_nr in tqdm(range(self.nr_claims)):
            path_claim = os.path.join(self.path_sentence_retrieval_dir,
                                      str(claim_nr) + '.json')
            claim_dict = dict_load_json(path_claim)

            prob_list = []
            prob_list_supported = []
            prob_list_refuted = []
            # indexing below assumes the stage 3 model orders its outputs as
            # (NOT ENOUGH INFO, REFUTES, SUPPORTS)
            for i in range(len(
                    claim_dict['sentence_retrieval']['doc_nr_list'])):
                doc_nr = claim_dict['sentence_retrieval']['doc_nr_list'][i]
                line_nr = claim_dict['sentence_retrieval']['line_nr_list'][i]
                if doc_nr in claim_dict['document_retrieval']:
                    if line_nr in claim_dict['document_retrieval'][doc_nr]:
                        prob = compute_prob_stage_3(model, claim_dict, doc_nr,
                                                    line_nr, device)
                        prob_list.append(prob)
                        prob_list_supported.append(prob[2])
                        prob_list_refuted.append(prob[1])
                    else:
                        print('line_nr not in list', line_nr)
                else:
                    print('doc_nr not in list', doc_nr)

            # the empty-list guards keep max() from failing when no evidence
            # line yielded a probability
            if prob_list_supported and max(prob_list_supported) > 0.5:
                claim_dict['predicted_label'] = 'SUPPORTS'
            elif prob_list_refuted and max(prob_list_refuted) > 0.5:
                claim_dict['predicted_label'] = 'REFUTES'
            else:
                claim_dict['predicted_label'] = 'NOT ENOUGH INFO'

            path_save = os.path.join(self.path_label_prediction_dir,
                                     str(claim_nr) + '.json')
            self.save_dict(claim_dict, path_save)
Example #8
    def __init__(self,
                 wiki_database,
                 nlp,
                 path_stage_2_model,
                 path_stage_3_model,
                 path_dir_doc_selected,
                 method_tokenization,
                 path_base_dir,
                 path_word_dict_stage_2,
                 path_word_dict_stage_3,
                 embeddings_settings_sentence_retrieval_list=None,
                 embeddings_settings_label_prediction_list=None):
        # === process inputs === #
        # avoid mutable default arguments
        if embeddings_settings_sentence_retrieval_list is None:
            embeddings_settings_sentence_retrieval_list = []
        if embeddings_settings_label_prediction_list is None:
            embeddings_settings_label_prediction_list = []

        self.wiki_database = wiki_database
        self.path_stage_2_model = path_stage_2_model
        self.path_stage_3_model = path_stage_3_model
        self.nlp = nlp
        self.path_dir_doc_selected = path_dir_doc_selected
        self.method_tokenization = method_tokenization
        self.path_base_dir = path_base_dir
        self.path_word_dict_stage_2 = path_word_dict_stage_2
        self.path_word_dict_stage_3 = path_word_dict_stage_3
        self.embeddings_settings_sentence_retrieval_list = embeddings_settings_sentence_retrieval_list
        self.embeddings_settings_label_prediction_list = embeddings_settings_label_prediction_list

        # === paths === #
        self.path_document_retrieval_dir = os.path.join(
            path_base_dir,
            get_file_name_from_variable_list(['document_retrieval']))
        self.path_sentence_retrieval_dir = os.path.join(
            path_base_dir, 'sentence_retrieval')
        self.path_label_prediction_dir = os.path.join(path_base_dir,
                                                      'label_prediction')

        for embeddings_setting in embeddings_settings_sentence_retrieval_list:
            self.path_sentence_retrieval_dir = get_file_name_from_variable_list(
                [self.path_sentence_retrieval_dir, embeddings_setting])

        for embeddings_setting in embeddings_settings_label_prediction_list:
            self.path_label_prediction_dir = get_file_name_from_variable_list(
                [self.path_label_prediction_dir, embeddings_setting])

        if not os.path.isdir(self.path_base_dir):
            os.makedirs(self.path_base_dir)

        self.path_settings = os.path.join(self.path_base_dir, 'settings.json')

        if os.path.isfile(self.path_settings):
            self.settings = dict_load_json(self.path_settings)
        else:
            self.settings = {}

        if 'nr_claims' not in self.settings:
            self.settings['nr_claims'] = self.nr_files_in_dir(
                self.path_dir_doc_selected)
            self.save_settings()

        self.nr_claims = self.settings['nr_claims']
        print('nr claims:', self.nr_claims)

        # === process === #
        self.tag_2_id_dict = get_tag_2_id_dict_unigrams()

        # each stage runs only once: it is skipped when its output directory
        # already exists, so cached per-claim files are reused
        if not os.path.isdir(self.path_document_retrieval_dir):
            os.makedirs(self.path_document_retrieval_dir)
            self.document_retrieval()

        if not os.path.isdir(self.path_sentence_retrieval_dir):
            os.makedirs(self.path_sentence_retrieval_dir)
            self.sentence_retrieval()

        if not os.path.isdir(self.path_label_prediction_dir):
            os.makedirs(self.path_label_prediction_dir)
            self.label_prediction()

        self.compute_score()
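
Constructing this object drives the whole pipeline: each missing stage directory triggers its stage, and compute_score() runs at the end. A hypothetical driver, in which the class name Pipeline, the wiki_database object, and every path are assumptions, since the excerpts never show them:

import spacy

# all names and paths below are assumed for illustration only
nlp = spacy.load('en_core_web_sm')
pipeline = Pipeline(
    wiki_database=wiki_database,  # assumed pre-built wiki database object
    nlp=nlp,
    path_stage_2_model='models/stage_2/best.pth.tar',
    path_stage_3_model='models/stage_3/best.pth.tar',
    path_dir_doc_selected='results/doc_selected',
    method_tokenization='tokenize',
    path_base_dir='results/pipeline',
    path_word_dict_stage_2='models/stage_2/worddict.pkl',
    path_word_dict_stage_3='models/stage_3/worddict.pkl')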