Example #1
    def __call__(self, doc, **kwargs):
        output_set_name = kwargs.get("outputASName", "GATEML")
        doc_text = doc.text
        current_gate_file_name = doc.features['gate.SourceURL']
        current_gate_file_base_name = os.path.basename(current_gate_file_name)

        #print(doc._name)
        workingSet = doc.annset(self.workingSet)
        config = {'TARGET': {'labels': self.mm.target_labels}}

        test_dataIter = GateReader(postProcessor=self.readerPostProcessor,
                                   config=config)

        outputType = 'MLpred'

        # Two modes: classify each instanceType annotation individually, or treat
        # the whole document as a single instance when no instanceType is set.
        if self.instanceType:
            outputType = self.instanceType

            instanceSet = workingSet.with_type([self.instanceType])
            for instanceAnno in instanceSet:
                current_instance_text = doc_text[instanceAnno.start:instanceAnno.end]
                if self.targetFeature:
                    current_instance_target_feature = instanceAnno.features[
                        self.targetFeature]
                else:
                    ### add a dummy target
                    current_instance_target_feature = self.mm.target_labels[0]
                test_dataIter.addSample(current_instance_text,
                                        current_instance_target_feature,
                                        anno_start=instanceAnno.start,
                                        anno_end=instanceAnno.end)
        else:
            current_instance_text = doc_text
            current_instance_target_feature = self.mm.target_labels[0]
            if self.targetFile:
                if current_gate_file_base_name in self.target_dict:
                    current_instance_target_feature = self.target_dict[
                        current_gate_file_base_name]
            test_dataIter.addSample(current_instance_text,
                                    current_instance_target_feature,
                                    anno_start=0,
                                    anno_end=len(current_instance_text))

        test_dataIter._reset_iter()

        apply_output_dict = self.mm.apply(test_dataIter)
        output_set = doc.annset(output_set_name)
        output_set.clear()
        #print(apply_output_dict['all_cls_att'])

        # Re-iterate the reader in GATE-apply mode so each item pairs the raw
        # sample dict (with annotation offsets) with its BERT tokenisation.
        test_dataIter.postProcessor.postProcessMethod = 'postProcess4GATEapply'
        for each_sample_id, dataIterItem in enumerate(test_dataIter):
            each_sample = dataIterItem[0]
            bert_tokenized = dataIterItem[1]
            pred_score = apply_output_dict['all_prediction'][each_sample_id]
            pred_label_string = apply_output_dict['all_pred_label_string'][
                each_sample_id]
            cls_att = apply_output_dict['all_cls_att'][each_sample_id]
            anno_start = each_sample['anno_start']
            anno_end = each_sample['anno_end']
            output_feature_map = {
                'pred_score': pred_score,
                self.targetFeature: pred_label_string
            }
            output_set.add(anno_start, anno_end, outputType,
                           output_feature_map)

            #recon_token_list, topn_indices, topn_values = single_att_reconstruction(bert_tokenized, cls_att)
            #off_set_dict = construct_offset_id(doc_text, recon_token_list)

            #print(len(recon_token_list), len(token_offset_list))

            if self.resultsExportFile:
                result_export_line = '\t'.join([
                    current_gate_file_base_name,
                    str(anno_start),
                    str(anno_end),
                    pred_label_string,
                    doc_text[anno_start:anno_end]
                ]) + '\n'
                self.f_results_export.write(result_export_line)

            if not self.instanceType and self.targetFile and self.target2GateType:
                output_feature_map = {
                    self.targetFeature: each_sample['target']
                }
                output_set.add(anno_start, anno_end, self.target2GateType,
                               output_feature_map)

            ###export attention
            #for att_id, att_word_index in enumerate(topn_indices):
            #    att_score = topn_values[att_id]
            #    if att_word_index in off_set_dict:
            #        att_feature_map = {'score':str(att_score)}
            #        print(off_set_dict[att_word_index][0], off_set_dict[att_word_index][1], len(doc_text))
            #        output_set.add(off_set_dict[att_word_index][0], off_set_dict[att_word_index][1], 'attentions', att_feature_map)

        test_dataIter.postProcessor.postProcessMethod = 'postProcess4Model'
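
This __call__ is the apply-time half of the PR: it rebuilds a GateReader over the working annotation set, runs the trained model through ModelManager.apply, and adds one annotation per instance carrying pred_score and the predicted label. Below is a minimal sketch of driving such a PR directly from Python with gatenlp; the class name GateMLApply, its start() kwargs and the model path are assumptions (only __call__ is shown above), while the gatenlp Document calls are the library's standard API.

from gatenlp import Document

# Hypothetical driver: "GateMLApply" and the start() kwargs below are assumed
# to mirror the attributes referenced in __call__ above.
pr = GateMLApply()
pr.start(workingSet="", instanceType="Sentence", targetFeature="label",
         model_path="models/bert_cls", gpu="False")

doc = Document("The committee approved the proposal without changes.")
doc.features["gate.SourceURL"] = "file:/corpus/doc0001.xml"
# One annotation per classification instance, placed in the working set.
doc.annset("").add(0, len(doc.text), "Sentence", {"label": "unknown"})

pr(doc)  # writes predictions into the "GATEML" output set
for ann in doc.annset("GATEML"):
    print(ann.type, ann.start, ann.end,
          ann.features.get("label"), ann.features.get("pred_score"))
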
Example #2
class GateMLTrain:
    def __init__(self):
        self.script_path = os.path.abspath(__file__)
        self.processorLoger = getLogger('processorLoger')

    def start(self, **kwargs):
        #print(kwargs)
        #self.all_doc = []
        readerPostProcessor = BertPostProcessor(x_fields=['text'], y_field='target')
        self.train_dataIter = GateReader(postProcessor=readerPostProcessor, shuffle=True)

        self.workingSet = kwargs.get('workingSet', '')
        self.instanceType = kwargs.get('instanceType', None)
        self.targetType = kwargs.get('targetType', None)
        self.targetFeature = kwargs.get('targetFeature', None)
        self.targetFile = kwargs.get('targetFile', None)
        self.gpu = str_to_bool(kwargs.get('gpu', 'False'))
        self.model_path = kwargs.get('model_path')

        if self.targetFile:
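            # targetFile layout: the first line gives the document file suffix,
            # each following line is "<file name>\t<target label>".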
            self.target_dict = {}
            with open(self.targetFile, 'r') as ft:
                for each_line_id, each_line in enumerate(ft):
                    if each_line_id == 0:
                        self.file_suffix = each_line.strip()
                    else:
                        line_tok = each_line.split('\t')
                        self.target_dict[line_tok[0]] = line_tok[1].strip()

    def finish(self, **kwargs):
        self.train_dataIter.finaliseReader()
        print(len(self.train_dataIter))

        val_dataIter = None
        print(self.train_dataIter.target_labels)
        
        dummy_config = {'MODEL':{'n_classes':len(self.train_dataIter.target_labels)}}

        self.mm = ModelManager(gpu=self.gpu, config=dummy_config)
        self.mm.genPreBuildModel()


        if 'splitValidation' in kwargs:
            self.train_dataIter, val_dataIter = self.mm.splitValidation(self.train_dataIter, val_split=float(kwargs.get('splitValidation')))

        self.mm.train(self.train_dataIter, save_path=self.model_path, valDataIter=val_dataIter, earlyStopping=True, patience=5)

    def __call__(self, doc, **kwargs):
        doc_text = doc.text
        #print(doc._name)
        #print(doc.features)
        current_gate_file_name = doc.features['gate.SourceURL']
        current_gate_file_base_name = os.path.basename(current_gate_file_name)
        #print(current_gate_file_base_name)

        workingSet = doc.annset(self.workingSet)

        if self.instanceType:
            instanceSet = workingSet.with_type([self.instanceType])
            for instanceAnno in instanceSet:
                #print(instanceAnno)
                #print(instanceAnno.start)
                #print(instanceAnno.end)
                current_instance_text = doc_text[instanceAnno.start:instanceAnno.end]
                current_instance_target_feature = instanceAnno.features[self.targetFeature]
                #print(current_instance_text, current_instance_target_feature)
                self.train_dataIter.addSample(current_instance_text, current_instance_target_feature)
        elif self.targetFile:
            current_instance_text = doc_text
            if current_gate_file_base_name in self.target_dict:
                current_instance_target_feature = self.target_dict[current_gate_file_base_name] 
                self.train_dataIter.addSample(current_instance_text, current_instance_target_feature)
            else:
                infomessage = 'no target found, discarding ' + current_gate_file_name
                self.processorLoger.info(infomessage)
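
GateMLTrain gathers one sample per instanceType annotation, or one sample per document when a targetFile is given, and only builds and trains the model in finish(). Below is a minimal sketch of the document-level workflow driven directly from Python; the target-file layout follows the parsing in start() above, while the file names, texts and labels are invented for illustration.

from gatenlp import Document

# Target file expected by start(): line 1 holds the document file suffix,
# every further line is "<file name>\t<target label>" (values invented here).
with open("targets.tsv", "w") as ft:
    ft.write(".xml\n")
    ft.write("doc0001.xml\tpositive\n")
    ft.write("doc0002.xml\tnegative\n")

trainer = GateMLTrain()
trainer.start(workingSet="", targetFile="targets.tsv",
              gpu="False", model_path="models/bert_cls")

for name, text in [("doc0001.xml", "Great results this quarter."),
                   ("doc0002.xml", "The trial was stopped early.")]:
    doc = Document(text)
    doc.features["gate.SourceURL"] = "file:/corpus/" + name
    trainer(doc)      # whole document becomes one labelled training sample

trainer.finish()      # finalises the reader, builds the model, trains and saves
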
Example #3
    def __call__(self, doc, **kwargs):
        output_set_name = kwargs.get("outputASName", "GATEML")
        doc_text = doc.text
        current_gate_file_name = doc.features['gate.SourceURL']
        current_gate_file_base_name = os.path.basename(current_gate_file_name)

        #print(doc._name)
        workingSet = doc.annset(self.workingSet)
        config = {'TARGET': {'labels': self.mm.target_labels}}

        test_dataIter = GateReader(postProcessor=self.readerPostProcessor,
                                   config=config)

        outputType = 'MLpred'

        if self.instanceType:
            outputType = self.instanceType

            instanceSet = workingSet.with_type([self.instanceType])
            for instanceAnno in instanceSet:
                current_instance_text = doc_text[instanceAnno.start:instanceAnno.end]
                if self.targetFeature:
                    current_instance_target_feature = instanceAnno.features[
                        self.targetFeature]
                else:
                    ### add a dummy target
                    current_instance_target_feature = self.mm.target_labels[0]
                test_dataIter.addSample(current_instance_text,
                                        current_instance_target_feature,
                                        anno_start=instanceAnno.start,
                                        anno_end=instanceAnno.end)
        else:
            current_instance_text = doc_text
            current_instance_target_feature = self.mm.target_labels[0]
            if self.targetFile:
                if current_gate_file_base_name in self.target_dict:
                    current_instance_target_feature = self.target_dict[
                        current_gate_file_base_name]
            test_dataIter.addSample(current_instance_text,
                                    current_instance_target_feature,
                                    anno_start=0,
                                    anno_end=len(current_instance_text))

        test_dataIter._reset_iter()

        apply_output_dict = self.mm.apply(test_dataIter)
        output_set = doc.annset(output_set_name)
        output_set.clear()

        # Turn off the reader post-processing so iteration yields the raw sample
        # dicts (with anno_start/anno_end) rather than post-processed model inputs.
        test_dataIter.goPoseprocessor = False
        for each_sample_id, each_sample in enumerate(test_dataIter):
            pred_score = apply_output_dict['all_prediction'][each_sample_id]
            pred_label_string = apply_output_dict['all_pred_label_string'][
                each_sample_id]
            anno_start = each_sample['anno_start']
            anno_end = each_sample['anno_end']
            output_feature_map = {
                'pred_score': pred_score,
                self.targetFeature: pred_label_string
            }
            output_set.add(anno_start, anno_end, outputType,
                           output_feature_map)
            if self.resultsExportFile:
                result_export_line = '\t'.join([
                    current_gate_file_base_name,
                    str(anno_start),
                    str(anno_end),
                    pred_label_string,
                    doc_text[anno_start:anno_end]
                ]) + '\n'
                self.f_results_export.write(result_export_line)
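
When resultsExportFile is set, both apply-time variants also write each prediction as one tab-separated line: document name, annotation start, annotation end, predicted label, covered text. A small sketch for reading such an export back; the file name is invented, and note that the covered text is written unescaped, so slices containing tabs or newlines would break a naive parse.

# Columns as written above: doc name, start offset, end offset, label, text.
with open("ml_results.tsv") as f:
    for line in f:
        doc_name, start, end, label, covered_text = line.rstrip("\n").split("\t", 4)
        print(doc_name, int(start), int(end), label, covered_text[:40])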