コード例 #1
0
    def predict(self,
                input_file,
                output_file,
                test_file="../data/results/result_log2.csv"):
        """Report every sample whose ensemble prediction differs from its label.

        ``input_file`` is read as JSON Lines: one JSON object per line with the
        keys ``id``, ``passage``, ``question`` and ``label``.  Each
        (passage, question) pair is classified by every model in
        ``self.cls_entity_predictor_pool``; the votes are merged with
        ``cls_entity_ensemble``.  Samples whose merged label is not equal to
        the ``label`` field are written to ``output_file``.

        :param input_file: path of the UTF-8 JSON Lines file to evaluate.
        :param output_file: path of the CSV report; the columns are
            ``id,passage,question,label,p_label``.
        :param test_file: not used by this method; kept so that existing
            callers passing it keep working.
        :return: None
        """
        # Local import so the method does not depend on the module header.
        import csv

        items = []
        with open(input_file, encoding='utf-8') as infile:
            for line in infile:
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                if not line.strip():
                    continue
                row = json.loads(line)
                items.append({
                    'id': row['id'],
                    'passage': row['passage'],
                    'entity': row['question'],
                    'label': row['label'],
                })

        outputs = []
        for item in tqdm(items, desc="正在预测", ncols=80):
            passage = item['passage']
            entity = item['entity']
            # Character budget for the passage.  NOTE(review): the 3 and 2 are
            # presumably reserved for special tokens -- confirm against the
            # predictor.  Clamp at 0 so an over-long question cannot turn the
            # bound negative (which would trim the tail instead of capping).
            budget = max(0, 512 - 3 - 2 - len(entity))
            votes = [
                model.predict(passage[:budget], entity)
                for model in self.cls_entity_predictor_pool
            ]
            predicted = cls_entity_ensemble(votes)
            if predicted != item['label']:
                outputs.append(dict(item, p_label=predicted))

        columns = ['id', 'passage', 'entity', 'label', 'p_label']
        with open(output_file, 'w', encoding='utf-8') as f:
            # csv.writer quotes fields containing commas, quotes or newlines,
            # so passages with commas no longer shift the columns.  The "\n"
            # terminator keeps the line endings the hand-written rows had.
            writer = csv.writer(f, lineterminator="\n")
            writer.writerow(["id", "passage", "question", "label", "p_label"])
            for item in outputs:
                # str() keeps non-string ids/labels from raising.
                writer.writerow([str(item[key]) for key in columns])
コード例 #2
0
    def predict(self, input_file, output_file):
        """Flag negative entities for every row of a CSV file.

        ``input_file`` is a CSV with (at least) the columns ``id``, ``text``
        and ``entity``; ``entity`` holds the candidates separated by ``;``.
        For every candidate the question ``<entity>不好`` ("<entity> is bad")
        is paired with the first 450 characters of ``text`` and classified by
        every model in ``self.cls_entity_predictor_pool``.  The votes are
        merged with ``cls_entity_ensemble``; a merged label of "正类"
        (positive class) marks the candidate as a negative entity.

        :param input_file: path of the UTF-8 CSV file (with or without BOM).
        :param output_file: path of the result file; the columns are
            ``id,negative,key_entity`` where ``negative`` is 1 when at least
            one entity was flagged and ``key_entity`` joins them with ``;``.
        :return: None
        """
        # "utf-8-sig" strips a leading BOM when present, so the first column
        # is always addressable as "id" instead of "\ufeffid".
        with open(input_file, encoding='utf-8-sig') as csvfile:
            items = []
            for row in csv.DictReader(csvfile):
                entities = row['entity'].split(";")
                items.append({
                    'id': row['id'],
                    'passage': row['text'][:450],
                    'entity_list': entities,
                    'question_list': [entity + "不好" for entity in entities],
                })

        outputs = []
        for item in tqdm(items, desc="正在预测", ncols=80):
            passage = item['passage']
            negative_entity_list = []
            for entity, question in zip(item['entity_list'],
                                        item['question_list']):
                votes = [
                    model.predict(passage, question)
                    for model in self.cls_entity_predictor_pool
                ]
                if cls_entity_ensemble(votes) == "正类":
                    negative_entity_list.append(entity)
            outputs.append({
                'id': item['id'],
                'entity': negative_entity_list,
                'negative': 1 if negative_entity_list else 0,
            })

        with open(output_file, 'w', encoding='utf-8') as f:
            f.write("id,negative,key_entity\n")
            for item in outputs:
                # The output dicts store the flagged entities under "entity";
                # reading "entity_list" here raised KeyError on the first row.
                f.write(','.join([item['id'],
                                  str(item['negative']),
                                  ';'.join(item['entity'])]))
                f.write("\n")
コード例 #3
0
    def predict(self, input_file, output_file):
        """Classify each passage at sentence level and store the results.

        ``input_file`` is read as JSON Lines; every line must provide ``id``,
        ``passage`` and ``entity``.  The truncated passage, paired with an
        empty second text, is scored by every model in
        ``self.cls_sentence_predictor_pool`` and the votes are merged with
        ``cls_entity_ensemble``.  ``negative`` is 1 when the merged label is
        "正类" (positive class) and 0 otherwise; the entities are passed
        through unchanged.  The collected records are handed to
        ``write_file`` together with ``output_file``.

        :param input_file: path of the UTF-8 JSON Lines input file.
        :param output_file: destination passed on to ``write_file``.
        :return: None
        """
        records = []
        with open(input_file, encoding='utf-8') as infile:
            for line in infile:
                parsed = json.loads(line)
                records.append({
                    'id': parsed['id'],
                    'passage': parsed['passage'],
                    'entity': list(parsed['entity']),
                })

        # Passage character cap.  NOTE(review): the 3 and 2 are presumably
        # reserved for special tokens -- confirm against the predictor.
        max_chars = 512 - 3 - 2

        results = []
        for record in tqdm(records, desc="正在预测", ncols=80):
            text = record['passage']
            votes = [
                predictor.predict(text[:max_chars], "")
                for predictor in self.cls_sentence_predictor_pool
            ]
            flagged = cls_entity_ensemble(votes) == "正类"
            results.append({
                'id': record['id'],
                'passage': text,
                'entity': record['entity'],
                'negative': 1 if flagged else 0,
            })

        write_file(results, output_file)
コード例 #4
0
    def predict(self, input_file, output_file):
        """Batched two-stage prediction: sentence polarity, then entities.

        ``input_file`` is read as JSON Lines; every line must provide ``id``,
        ``passage`` and ``entity`` (a list of candidate entities).  For each
        batch:

        1. Overlapping candidates are de-duplicated: when one entity is a
           substring of another and both occur equally often in the passage,
           the contained one is dropped.
        2. Every model in ``self.cls_sentence_predictor_pool`` classifies the
           truncated passages; votes are merged with ``cls_entity_ensemble``.
        3. For passages whose merged label is "正类" (positive class), every
           model in ``self.cls_entity_predictor_pool`` classifies each
           (passage, entity) pair; entities whose merged label is "正类" are
           reported as negative entities.

        NOTE(review): ``batch_size`` is neither a parameter nor a local of
        this method; it has to be supplied by an enclosing/module scope --
        confirm where it is defined.

        :param input_file: path of the UTF-8 JSON Lines input file.
        :param output_file: path of the result file; the columns are
            ``id,negative,key_entity`` with entities joined by ``;``.
        :return: None
        """
        items = []
        with open(input_file, encoding='utf-8') as infile:
            for line in infile:
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                if not line.strip():
                    continue
                row = json.loads(line)
                items.append({
                    'id': row['id'],
                    'passage': row['passage'],
                    'entity': list(row['entity']),
                })

        batch_items_list = [
            items[start:start + batch_size]
            for start in range(0, len(items), batch_size)
        ]

        outputs = []
        count = 0  # number of entity pairs collapsed by the de-duplication
        for batch_items in tqdm(batch_items_list, desc="正在预测", ncols=80):
            batch_inputs_sentence = []
            batch_inputs_entity = []
            batch_entity_list = []
            batch_output_item = []
            for item in batch_items:
                passage = item['passage']
                # NOTE(review): the 3 and 2 are presumably reserved for
                # special tokens -- confirm against the predictor.
                batch_inputs_sentence.append((passage[:512 - 3 - 2], ""))

                # Drop the contained entity of every overlapping pair that
                # occurs equally often in the passage.
                candidates = item['entity']
                pop_index = set()
                for i in range(len(candidates)):
                    for j in range(i + 1, len(candidates)):
                        first, second = candidates[i], candidates[j]
                        if first not in second and second not in first:
                            continue
                        if passage.count(first) != passage.count(second):
                            continue
                        count += 1
                        pop_index.add(i if first in second else j)
                entity_list = [
                    entity for k, entity in enumerate(candidates)
                    if k not in pop_index
                ]
                batch_entity_list.append(entity_list)

                # Clamp at 0 so an over-long entity cannot make the slice
                # bound negative (which would trim the passage tail instead).
                batch_inputs_entity.append([
                    (passage[:max(0, 512 - 3 - 2 - len(entity))], entity)
                    for entity in entity_list
                ])

                batch_output_item.append({
                    'id': item['id'],
                    'entity': [],
                    'negative': 0,
                })

            # Transpose (model, sample) -> (sample, model) so that each row
            # holds all votes for one sample.
            batch_sentence_label_ensemble = np.array([
                model.batch_predict(batch_inputs_sentence)
                for model in self.cls_sentence_predictor_pool
            ]).T.tolist()

            for idx, sentence_votes in enumerate(batch_sentence_label_ensemble):
                sentence_label = cls_entity_ensemble(sentence_votes)
                if sentence_label != "正类":
                    # Keep the defaults (no entities, negative=0); previously
                    # the entity list of an earlier sample leaked in here.
                    continue

                entity_inputs = batch_inputs_entity[idx]
                if not entity_inputs:
                    print(sentence_label, batch_inputs_sentence[idx])
                    # Skip only this sample; a ``break`` here would silently
                    # drop the predictions of the rest of the batch.
                    continue

                batch_entity_label_ensemble = np.array([
                    model.batch_predict(entity_inputs)
                    for model in self.cls_entity_predictor_pool
                ]).T.tolist()

                # Pair every vote row with the entity it was computed for
                # (not with a stale loop variable from the batching step).
                negative_entity_list = [
                    entity
                    for entity, entity_votes in zip(batch_entity_list[idx],
                                                    batch_entity_label_ensemble)
                    if cls_entity_ensemble(entity_votes) == "正类"
                ]
                batch_output_item[idx]['entity'] = negative_entity_list
                batch_output_item[idx]['negative'] = (
                    1 if negative_entity_list else 0)
            outputs.extend(batch_output_item)
        print("去重 entity 的个数", count)

        with open(output_file, 'w', encoding='utf-8') as f:
            f.write("id,negative,key_entity\n")
            for item in outputs:
                # str() keeps a non-string JSON id from raising TypeError.
                f.write(','.join([str(item['id']),
                                  str(item['negative']),
                                  ';'.join(item['entity'])]))
                f.write("\n")
コード例 #5
0
    def predict(self, input_file, output_file, test_file="../data/results/result_log2.csv"):
        """Detect negative entities for every passage of a JSON Lines file.

        ``input_file`` is read as JSON Lines; every line must provide ``id``,
        ``passage`` and ``entity`` (a list of candidate entities).  Every
        passage is currently treated as negative at sentence level
        (``negative`` is forced to 1 on load).  Overlapping candidates are
        de-duplicated, then each remaining (passage, entity) pair is
        classified by every model in ``self.cls_entity_predictor_pool``; the
        votes are merged with ``cls_entity_ensemble`` and entities whose
        merged label is "正类" (positive class) are reported.

        :param input_file: path of the UTF-8 JSON Lines input file.
        :param output_file: path of the result file; the columns are
            ``id,negative,key_entity`` with entities joined by ``;``.
        :param test_file: path of a debug dump.  Nothing populates the dump
            list at the moment, so the file is created/truncated empty.
        :return: None
        """
        items = []
        with open(input_file, encoding='utf-8') as infile:
            for line in infile:
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                if not line.strip():
                    continue
                row = json.loads(line)
                items.append({
                    'id': row['id'],
                    'passage': row['passage'],
                    # Placeholder for a per-row sentence label: every passage
                    # is assumed negative for now.
                    'negative': 1,
                    'entity': list(row['entity']),
                })

        outputs = []
        print_items = []  # debug samples; intentionally left empty for now
        count = 0  # number of entity pairs collapsed by the de-duplication
        for item in tqdm(items, desc="正在预测", ncols=80):
            passage = item['passage']

            # Drop the contained entity of every overlapping pair that occurs
            # equally often in the passage.
            candidates = item['entity']
            pop_index = set()
            for i in range(len(candidates)):
                for j in range(i + 1, len(candidates)):
                    first, second = candidates[i], candidates[j]
                    if first not in second and second not in first:
                        continue
                    if passage.count(first) != passage.count(second):
                        continue
                    count += 1
                    pop_index.add(i if first in second else j)
            entity_list = [
                entity for k, entity in enumerate(candidates)
                if k not in pop_index
            ]

            output_item = {'id': item['id'], 'entity': [], 'negative': 0}

            if item['negative'] == 1:
                negative_entity_list = []
                for entity in entity_list:
                    # NOTE(review): the 3 and 2 are presumably reserved for
                    # special tokens -- confirm against the predictor.  Clamp
                    # at 0 so an over-long entity cannot make the bound
                    # negative (which would trim the passage tail instead).
                    budget = max(0, 512 - 3 - 2 - len(entity))
                    votes = [
                        model.predict(passage[:budget], entity)
                        for model in self.cls_entity_predictor_pool
                    ]
                    if cls_entity_ensemble(votes) == "正类":
                        negative_entity_list.append(entity)

                output_item['entity'] = negative_entity_list
                output_item['negative'] = 1 if negative_entity_list else 0

            outputs.append(output_item)
        print("去重 entity 的个数", count)

        with open(output_file, 'w', encoding='utf-8') as f:
            f.write("id,negative,key_entity\n")
            for item in outputs:
                # str() keeps a non-string JSON id from raising TypeError.
                f.write(','.join([str(item['id']),
                                  str(item['negative']),
                                  ';'.join(item['entity'])]))
                f.write("\n")

        with open(test_file, 'w', encoding='utf-8') as f:
            for item in print_items:
                # Passage and negative flag are now separated by a comma;
                # they used to be glued together into one column.
                f.write(','.join([str(item['id']),
                                  str(item['passage']),
                                  str(item['negative']),
                                  ';'.join(item['entity'])]))
                f.write("\n")