Example #1
    def process(self, dataset, path=None):
        datable = DataTable()

        for item in dataset:
            words = item['words']
            triggers = ['O'] * len(words)
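            # mark each trigger span with BIO tags: B- on its first token, I- on the rest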
            for event_mention in item['golden-event-mentions']:
                for i in range(event_mention['trigger']['start'],
                               event_mention['trigger']['end']):
                    trigger_type = event_mention['event_type']
                    if i == event_mention['trigger']['start']:
                        triggers[i] = 'B-{}'.format(trigger_type)
                    else:
                        triggers[i] = 'I-{}'.format(trigger_type)
            input_id, attention_mask, segment_id, valid_mask, label_id, label_mask = process(
                words, triggers, self.tokenizer, self.vocabulary,
                self.max_length)
            datable('input_ids', input_id)
            datable('attention_mask', attention_mask)
            datable('segment_ids', segment_id)
            datable('valid_masks', valid_mask)
            datable('label_ids', label_id)
            datable('label_masks', label_mask)

        if path and os.path.exists(path):
            datable.save_table(path)
        return datable
Example #2
    def _load(self, path):
        dataset = DataTable()
        sentence = []
        label = []
        frame = -1
        pos = -1
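        # CoNLL-style file: one token per line with tab-separated columns; blank lines end a sentence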
        with open(path) as f:
            for line in f:
                if len(line) == 0 or line[0] == "\n":
                    if len(sentence) > 0:
                        dataset('sentence', sentence)
                        dataset('label', label)
                        dataset('frame', frame)
                        dataset('pos', pos)
                        sentence = []
                        label = []
                        frame = -1
                        pos = -1
                    continue
                words = line.split('\t')
                sentence.append(words[1])
                element = words[-2].replace('S-', 'B-')  # collapse single-token S- tags into B-
                label.append(element)
                if words[-3] != '_':  # a non-placeholder value marks the frame-evoking token
                    pos = len(sentence) - 1
                    frame = words[-3]
                    self.trigger_label_set.add(frame)

                self.argument_label_set.add(element)
            if len(sentence) > 0:
                dataset('sentence', sentence)
                dataset('label', label)
                dataset('frame', frame)
                dataset('pos', pos)
        return dataset
Example #3
 def _load(self, path):
     dataset = DataTable()
     data = load_json(path)
     for item in data:
         ner_label = []
         rc_label = []
         ner_check = []
         rc_check = []
         text = item["text"].split(" ")
         for label in item["triple_list"]:
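             # NOTE: list.index() finds only the first occurrence, so each triple is
             # assumed to name single tokens that appear once in the sentence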
             subject_word_loc = text.index(label[0])
             relation = label[1]
             object_word_loc = text.index(label[2])
             # ner_check records token positions already emitted as entity spans
             if subject_word_loc not in ner_check:
                 ner_label.append(
                     [subject_word_loc, subject_word_loc, "None"])
                 ner_check.append(subject_word_loc)
             if object_word_loc not in ner_check:
                 ner_label.append(
                     [object_word_loc, object_word_loc, "None"])
                 ner_check.append(object_word_loc)
             rc_label.append([subject_word_loc, object_word_loc, relation])
         dataset("text", text)
         dataset("ner_label", ner_label)
         dataset("rc_label", rc_label)
     return dataset
Example #4
 def _load(self, path):
     datable = DataTable()
     with open(path) as f:
         for line in f:
             data = json.loads(line)
             text = data['text']
             words = list(text)  # one entry per character of the text
             if 'spo_list' in data:
                 spo_list = data['spo_list']
                 for spo in spo_list:
                     if 'predicate' in spo:
                         relation = spo['predicate']
                         subject = spo['subject']
                         obj = spo['object']['@value']  # renamed to avoid shadowing the builtin `object`
                         subject_pos = get_position(text, subject)
                         object_pos = get_position(text, obj)
                         if subject_pos is None or object_pos is None:
                             continue
                         subj_start = subject_pos[0]
                         subj_end = subject_pos[1]
                         obj_start = object_pos[0]
                         obj_end = object_pos[1]
                         self.label_set.add(relation)
                         datable('token', words)
                         datable('relation', relation)
                         datable('subj_start', subj_start)
                         datable('subj_end', subj_end)
                         datable('obj_start', obj_start)
                         datable('obj_end', obj_end)
     return datable
Example #5
 def process(self, datasets):
     datable = DataTable()
     for dataset in tqdm(datasets, desc="Processing"):
         text = dataset['text']
         entities = dataset['entities']
         sentences_boundaries = dataset['sentences_boundaries']
         for sentences_boundary in sentences_boundaries:
             positions = []
             sentence = text[sentences_boundary[0]:sentences_boundary[1]]
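             # keep entity mentions that lie inside this sentence and whose URI contains 'entity'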
             for entity in entities:
                 if entity['boundaries'][0] >= sentences_boundary[0] and \
                         entity['boundaries'][1] <= sentences_boundary[1] and 'entity' in entity['uri']:
                     positions.append(entity['boundaries'])
             words, labels = get_labels(sentence, positions, text,
                                        sentences_boundary)
             input_id, attention_mask, segment_id, valid_mask, label_id, label_mask = process(
                 words, labels, self.tokenizer, self.vocabulary,
                 self.max_length)
             datable('input_ids', input_id)
             datable('attention_masks', attention_mask)
             datable('segment_ids', segment_id)
             datable('valid_mask', valid_mask)
             datable('label_ids', label_id)
             datable('label_masks', label_mask)
     return datable
Example #6
    def process(self, dataset):
        datable = DataTable()
        for i in range(len(dataset)):
            token, relation, subj_start, subj_end, obj_start, obj_end = dataset[i]
            label_id = self.vocabulary.to_index(relation)
            item = {
                'token': token,
                'h': {
                    'pos': [subj_start, subj_end + 1]
                },
                't': {
                    'pos': [obj_start, obj_end + 1]
                }
            }
            indexed_tokens, att_mask, pos1, pos2 = self.tokenize(item)
            datable('input_ids', indexed_tokens)
            datable('attention_mask', att_mask)
            datable('pos1', pos1)
            datable('pos2', pos2)
            datable('label_id', label_id)

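            # add the same pair again with head and tail swapped, labeled '<unk>' (a reversed negative sample)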
            datable('input_ids', indexed_tokens)
            datable('attention_mask', att_mask)
            datable('pos1', pos2)
            datable('pos2', pos1)
            datable('label_id', self.vocabulary.to_index('<unk>'))
        return datable
Example #7
 def process(self, dataset):
     datable = DataTable()
     for sentence, label in zip(dataset['sentence'], dataset['label']):
         input_id, attention_mask, segment_id, valid_mask, label_id, label_mask = process(
             sentence, label, self.tokenizer, self.vocabulary,
             self.max_length)
         datable('input_ids', input_id)
         datable('attention_mask', attention_mask)
         datable('segment_ids', segment_id)
         datable('valid_masks', valid_mask)
         datable('label_ids', label_id)
         datable('label_masks', label_mask)
     return datable
Example #8
 def _load(self, path):
     dataset = DataTable()
     with open(path) as f:
         lines = f.readlines()
         for line in lines:
             sample = json.loads(line)
             dataset("content", sample["content"])
             dataset("index", sample["index"])
             dataset("type", sample["type"])
             dataset("args", sample["args"])
             dataset("occur", sample["occur"])
             dataset("triggers", sample["triggers"])
             dataset("id", sample["id"])
     return dataset
Example #9
 def process(self, dataset):
     datable = DataTable()
     print("process data...")
     for text, ner_label, rc_label in tqdm(zip(dataset['text'],
                                               dataset['ner_label'],
                                               dataset['rc_label']),
                                           total=len(dataset['text'])):
         words, ner_labels, rc_labels, bert_len = self.process_item(
             text, ner_label, rc_label)
         datable('words', words)
         datable('ner_labels', ner_labels)
         datable('rc_labels', rc_labels)
         datable('bert_len', bert_len)
     return datable
Example #10
 def process(self, dataset):
     datable = DataTable()
     for data in tqdm(dataset, desc='Processing Data'):
         words = data['words']
         labels = data['labels']
         input_id, attention_mask, segment_id, label_id, label_mask = process(
             words, labels, self.tokenizer, self.vocabulary,
             self.max_length)
         datable('input_ids', input_id)
         datable('attention_mask', attention_mask)
         datable('segment_ids', segment_id)
         datable('label_ids', label_id)
         datable('label_masks', label_mask)
     return datable
Example #11
 def _load(self, path):
     dataset = DataTable()
     with open(path) as f:
         while True:
             line = f.readline()
             if not line:
                 break
             data = json.loads(line)
             dataset('words', data['tokens'])
             dataset('mentions', data['mentions'])
             mentions = data['mentions']
             for mention in mentions:
                 for label in mention['labels']:
                     self.label_set.add(label)
     return dataset
Example #12
 def process(self, dataset):
     datable = DataTable()
     for sentence, label in zip(dataset['sentence'], dataset['label']):
         input_id, attention_mask, segment_id, head_index, label_id, label_mask = process(
             sentence, label, self.tokenizer, self.vocabulary,
             self.max_length)
         if len(input_id) <= self.max_length and len(
                 head_index) <= self.max_length and len(
                     label_id) <= self.max_length:
             datable('input_ids', input_id)
             datable('attention_mask', attention_mask)
             datable('segment_ids', segment_id)
             datable('head_indexes', head_index)
             datable('label_ids', label_id)
             datable('label_masks', label_mask)
     return datable
Example #13
    def process(self, dataset):
        datable = DataTable()
        for i in range(len(dataset)):
            words, mentions, start, end = dataset[i]
            input_id, attention_mask, segment_id, head_index, label_id = \
                self.output(words, mentions, self.tokenizer, self.vocabulary, self.max_length)
            if len(input_id) <= self.max_length and len(
                    head_index) <= self.max_length:
                datable('input_id', input_id)
                datable('attention_mask', attention_mask)
                datable('segment_id', segment_id)
                datable('head_index', head_index)
                datable('label_id', label_id)
                datable('start', start)
                datable('end', end)

        return datable
Example #14
 def process(self, dataset):
     datable = DataTable()
     # add your own process code here
     for sample in zip(
             dataset['ex_id'],
             dataset['left_context'],
             dataset['right_context'],
             dataset['mention'],
             dataset['label'],
     ):
         input_ids, token_type_ids, attention_mask, target = process_ufet(
             sample, self.tokenizer, self.vocabulary, self.max_length)
         datable('input_ids', input_ids)
         datable('token_type_ids', token_type_ids)
         datable('attention_mask', attention_mask)
         datable('target', target)
     return datable
Example #15
 def _load(self, path):
     dataset = DataTable()
     with open(path) as f:
         while True:
             line = f.readline()
             if not line:
                 break
             words = line.split('\t')
             tokens = words[2].split(' ')
             mentions = words[3].strip().split(' ')
             for mention in mentions:
                 self.label_set.add(mention)
             dataset('words', tokens)
             dataset('mentions', mentions)
             dataset('start', int(words[0]))
             dataset('end', int(words[1]))
     return dataset
Example #16
 def process(self, dataset):
     datable = DataTable()
     for item in dataset.values():
         sentence = item['sentence']
         frames = item['frames']
         elements = item['elements']
         input_ids, attention_mask, head_indexes, frame_id, element_id, label_mask = process(
             sentence, frames, elements, self.tokenizer,
             self.frame_vocabulary, self.element_vocabulary,
             self.max_length)
         datable('input_ids', input_ids)
         datable('attention_mask', attention_mask)
         datable('head_indexes', head_indexes)
         datable('frame_id', frame_id)
         datable('element_id', element_id)
         datable('label_mask', label_mask)
     return datable
Example #17
 def _load(self, path):
     dataset = load_json(path)
     datable = DataTable()
     for data in dataset:
         token = data['token']
         relation = data['relation']
         subj_start = data['subj_start']
         subj_end = data['subj_end']
         obj_start = data['obj_start']
         obj_end = data['obj_end']
         self.label_set.add(relation)
         datable('token', token)
         datable('relation', relation)
         datable('subj_start', subj_start)
         datable('subj_end', subj_end)
         datable('obj_start', obj_start)
         datable('obj_end', obj_end)
     return datable
Example #18
 def process(self, dataset):
     datable = DataTable()
     for i in range(len(dataset)):
         sentence, label, frame, pos = dataset[i]
         input_id, attention_mask, segment_id, head_index, label_id, label_mask = process(sentence, label,
                                                                                          frame, pos,
                                                                                          self.tokenizer,
                                                                                          self.trigger_vocabulary,
                                                                                          self.argument_vocabulary,
                                                                                          self.max_length)
         datable('input_ids', input_id)
         datable('attention_mask', attention_mask)
         datable('segment_ids', segment_id)
         datable('head_indexes', head_index)
         datable('label_ids', label_id)
         datable('label_masks', label_mask)
         datable('frame', self.trigger_vocabulary.to_index(frame))
         datable('pos', pos)
     return datable
Example #19
    def process(self, dataset):
        datable = DataTable()

        for item in tqdm(dataset, desc='Processing Data'):
            words = item['words']
            item_entity_mentions = item['entity_mentions']
            item_relation_mentions = item['relation_mentions']
            input_ids, attention_mask, head_indexes, entity_mentions, relation_mentions, entity_mentions_mask, relation_mentions_mask = process(
                words, item_entity_mentions, item_relation_mentions,
                self.tokenizer, self.vocabulary, self.max_length)
            if len(input_ids) <= self.max_length and len(head_indexes) <= self.max_length\
                    and len(entity_mentions) <= self.max_length and len(relation_mentions) <= self.max_length:
                datable('input_ids', input_ids)
                datable('attention_mask', attention_mask)
                datable('head_indexes', head_indexes)
                datable('entity_mentions', entity_mentions)
                datable('relation_mentions', relation_mentions)
                datable('entity_mentions_mask', entity_mentions_mask)
                datable('relation_mentions_mask', relation_mentions_mask)

        return datable
Example #20
    def process(self, dataset):
        datable = DataTable()
        # add your own process code here
        for sentence, label in zip(dataset['sentence'], dataset['label']):
            bert_inputs, attention_masks, grid_labels, grid_mask2d, \
            pieces2word, dist_inputs, sent_length, entity_text = \
                process_w2ner(sentence, label, self.tokenizer,
                              self.vocabulary, self.max_length)

            datable('bert_inputs', bert_inputs)
            datable('attention_masks', attention_masks)
            datable('grid_labels', grid_labels)
            datable('grid_mask2d', grid_mask2d)
            datable('pieces2word', pieces2word)
            datable('dist_inputs', dist_inputs)
            datable('sent_length', sent_length)
            # datable('entity_text', entity_text)
            # entity_text is skipped for now: raw text is awkward to collate into batches

        return datable
Example #21
 def process(self, dataset):
     cnt = 0
     datable = DataTable()
     for i in range(len(dataset)):
         words, mentions = dataset[i]
         for mention in mentions:
             labels = mention['labels']
             input_id, attention_mask, segment_id, head_index, label_id = \
                 process(words, labels, self.tokenizer, self.vocabulary, self.max_length)
             if len(input_id) <= self.max_length and len(head_index) <= self.max_length:
                 datable('input_id', input_id)
                 datable('attention_mask', attention_mask)
                 datable('segment_id', segment_id)
                 datable('head_index', head_index)
                 datable('label_id', label_id)
                 datable('start', mention['start'])
                 datable('end', mention['end'])
             else:
                 cnt += 1
      print(cnt)  # number of samples dropped for exceeding max_length
     return datable
Example #22
 def process(self, dataset):
     datable = DataTable()
     for words, lemmas, node_types, node_attrs, origin_lexical_units, p2p_edges, p2r_edges, origin_frames, frame_elements in \
             tqdm(zip(dataset["words"], dataset["lemma"], dataset["node_types"],
                      dataset["node_attrs"], dataset["origin_lexical_units"], dataset["p2p_edges"],
                      dataset["p2r_edges"], dataset["origin_frames"], dataset["frame_elements"]),
                  total=len(dataset['words'])):
         tokens_x, token_masks, head_indexes, spans, \
         node_type_labels_list, node_attr_labels_list, \
         node_valid_attrs_list, valid_p2r_edges_list, \
         p2p_edge_labels_and_indices, p2r_edge_labels_and_indices, raw_words_len, n_spans = \
             self.process_item(words, lemmas, node_types, node_attrs, origin_lexical_units,
                               p2p_edges, p2r_edges, origin_frames, frame_elements)
         datable("tokens_x", tokens_x)
         datable("token_masks", token_masks)
         datable("head_indexes", head_indexes)
         datable("spans", spans)
         datable("node_type_labels_list", node_type_labels_list)  # coarse-grained node classification
         datable("node_attr_labels_list", node_attr_labels_list)  # fine-grained node classification
         datable("node_valid_attrs_list", node_valid_attrs_list)
         datable("valid_p2r_edges_list", valid_p2r_edges_list)
         datable("p2p_edge_labels_and_indices", p2p_edge_labels_and_indices)
         datable("p2r_edge_labels_and_indices", p2r_edge_labels_and_indices)
         datable("raw_words_len", raw_words_len)
         datable("n_spans", n_spans)
     return datable
Example #23
 def _load(self, path):
     dataset = DataTable()
     with open(path) as f:
         lines = f.readlines()
         for line in lines:
             sample = json.loads(line)
             dataset("words", sample["sentence"])
             dataset("lemma", sample["lemmas"])
             dataset("node_types", sample["node_types"])
             dataset("node_attrs", sample["node_attrs"])
             dataset("origin_lexical_units", sample["origin_lexical_units"])
             dataset("p2p_edges", sample["p2p_edges"])
             dataset("p2r_edges", sample["p2r_edges"])
             dataset("origin_frames", sample["origin_frames"])
             dataset("frame_elements", sample["frame_elements"])
             for item in sample["node_types"]:
                 self.node_types_set.add(item[1])
             for item in sample["node_attrs"]:
                 self.node_attrs_set.add(item[1])
             for item in sample["p2p_edges"]:
                 self.p2p_edges_set.add(item[-1])
             for item in sample["p2r_edges"]:
                 self.p2r_edges_set.add(item[-1])
     return dataset
Example #24
 def _load(self, path):
     dataset = DataTable()
     sentence = []
     label = []
     with open(path) as f:
         for line in f:
             if len(line) == 0 or line[0] == "\n":
                 if len(sentence) > 0:
                     dataset('sentence', sentence)
                     dataset('label', label)
                     sentence = []
                     label = []
                 continue
             line = line.strip()
             words = line.split(' ')
             sentence.append(words[0])
             label.append(words[-1])
             self.label_set.add(words[-1])
         if len(sentence) > 0:
             dataset('sentence', sentence)
             dataset('label', label)
         if len(dataset) == 0:
              raise RuntimeError("No data found in {}.".format(path))
     return dataset
Example #25
    def process_test(self, dataset):
        datable = DataTable()
        for content, index, type, args, occur, triggers, id in \
                tqdm(zip(dataset["content"], dataset["index"], dataset["type"],
                         dataset["args"], dataset["occur"], dataset["triggers"], dataset["id"]),
                     total=len(dataset["content"])):
            tokens_id, is_heads, head_indexes = [], [], []
            words = ['[CLS]'] + content + ['[SEP]']
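            # tokenize word by word; only the first subword of each word is marked as a head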
            for w in words:
                tokens = [w] if w in ['[CLS]', '[SEP]'] else self.tokenizer.tokenize(w)
                tokens_w_id = self.tokenizer.convert_tokens_to_ids(tokens)
                # if w in ['[CLS]', '[SEP]']:
                #     is_head = [0]
                # else:
                is_head = [1] + [0] * (len(tokens) - 1)
                tokens_id.extend(tokens_w_id)
                is_heads.extend(is_head)
            token_masks = [True] * len(tokens_id) + [False] * (
                self.max_length - len(tokens_id))
            token_masks = token_masks[:self.max_length]
            tokens_id = tokens_id + [0] * (self.max_length - len(tokens_id))
            tokens_id = tokens_id[:self.max_length]
            is_heads = is_heads[:self.max_length]
            for i in range(len(is_heads)):
                if is_heads[i]:
                    head_indexes.append(i)
            head_indexes = head_indexes + [0] * (self.max_length -
                                                 len(head_indexes))
            head_indexes = head_indexes[:self.max_length]

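            # multi-hot vector over every trigger type that occurs in this instance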
            type_vec = np.array([0] * self.trigger_type_num)
            type_id = -1
            if type != "<unk>":
                type_id = self.trigger_vocabulary.word2idx[type]
                for occ in occur:
                    idx = self.trigger_vocabulary.word2idx[occ]
                    type_vec[idx] = 1

            t_m = [0] * self.max_length
            r_pos = list(range(self.max_length))  # default relative positions when no trigger index is given
            r_pos = [p + self.max_length for p in r_pos]
            if index is not None:
                span = triggers[index]
                self.trigger_max_span_len[type] = max(
                    self.trigger_max_span_len[type], span[1] - span[0])
                start_idx = span[0] + 1  # +1 for the [CLS] token
                end_idx = span[1]  # +1 for [CLS], -1 to convert the exclusive end to inclusive
                r_pos = list(range(
                    -start_idx, 0)) + [0] * (end_idx - start_idx + 1) + list(
                        range(1, self.max_length - end_idx))
                r_pos = [p + self.max_length for p in r_pos]
                t_m = [0] * self.max_length
                t_m[start_idx] = 1
                t_m[end_idx] = 1

            t_index = index

            triggers_truth = [(span[0] + 1, span[1])
                              for span in triggers]  # trigger spans as closed intervals, shifted for [CLS]
            args_truth = {i: [] for i in range(self.argument_type_num)}
            for args_name in args:
                s_r_i = self.argument_vocabulary.word2idx[args_name]
                # s_r_i = self.args_s_id[args_name + '_s']
                for span in args[args_name]:
                    args_truth[s_r_i].append((span[0] + 1, span[1]))  # closed interval, shifted for [CLS]
            if type_id != -1:
                datable("data_ids", id)
                datable("type_id", type_id)
                datable("type_vec", type_vec)
                datable("tokens_id", tokens_id)
                datable("token_masks", token_masks)
                datable("t_index", t_index)
                datable("r_pos", r_pos)
                datable("t_m", t_m)
                datable("triggers_truth", triggers_truth)
                datable("args_truth", args_truth)
                datable("head_indexes", head_indexes)
                datable("content", content)
        return datable
Example #26
    def process_train(self, dataset):
        datable = DataTable()
        for content, index, type, args, occur, triggers, id in \
            tqdm(zip(dataset["content"], dataset["index"], dataset["type"],
                     dataset["args"], dataset["occur"], dataset["triggers"],dataset["id"]),total=len(dataset["content"])):
            tokens_id, is_heads, head_indexes = [], [], []
            # (commented-out per-word tokenization pass, superseded by tokenizer.encode_plus below)
            # iterating over the content string yields one character at a time
            data_content = [token.lower() for token in content]
            inputs = self.tokenizer.encode_plus(data_content,
                                                add_special_tokens=True,
                                                max_length=self.max_length,
                                                truncation=True,
                                                padding='max_length')
            tokens_id = inputs["input_ids"]
            segs = inputs["token_type_ids"]
            token_masks = inputs['attention_mask']
            head_indexes = list(np.arange(0, sum(token_masks)))
            head_indexes = head_indexes + [0] * (self.max_length -
                                                 len(head_indexes))
            head_indexes = head_indexes[:self.max_length]

            type_vec = np.array([0] * self.trigger_type_num)
            type_id = -1
            if type != "<unk>":
                type_id = self.trigger_vocabulary.word2idx[type]
                for occ in occur:
                    idx = self.trigger_vocabulary.word2idx[occ]
                    type_vec[idx] = 1

            t_m = [0] * self.max_length
            r_pos = list(range(self.max_length))  # default relative positions when no trigger index is given
            r_pos = [p + self.max_length for p in r_pos]
            if index is not None:
                span = triggers[index]
                self.trigger_max_span_len[type] = max(
                    self.trigger_max_span_len[type], span[1] - span[0])
                start_idx = span[0] + 1  # +1 for the [CLS] token
                end_idx = span[1]  # +1 for [CLS], -1 to convert the exclusive end to inclusive
                r_pos = list(range(
                    -start_idx, 0)) + [0] * (end_idx - start_idx + 1) + list(
                        range(1, self.max_length - end_idx))
                r_pos = [p + self.max_length for p in r_pos]
                t_m = [0] * self.max_length
                t_m[start_idx] = 1
                t_m[end_idx] = 1

            t_index = index

            # gold trigger start/end indicators over token positions

            t_s = [0] * self.max_length
            t_e = [0] * self.max_length

            for t in triggers:
                t_s[t[0] + 1] = 1  # +1 for [CLS]
                t_e[t[1]] = 1  # +1 for [CLS], -1 to make the end inclusive

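            # per-role start/end indicator matrices over token positions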
            args_s = np.zeros(shape=[self.argument_type_num, self.max_length])
            args_e = np.zeros(shape=[self.argument_type_num, self.max_length])
            arg_mask = [0] * self.argument_type_num
            for args_name in args:
                s_r_i = self.argument_vocabulary.word2idx[args_name]
                e_r_i = self.argument_vocabulary.word2idx[args_name]
                # s_r_i = self.args_s_id[args_name + '_s']
                # e_r_i = self.args_e_id[args_name + '_e']
                arg_mask[s_r_i] = 1
                for span in args[args_name]:
                    self.argument_max_span_len[args_name] = max(
                        span[1] - span[0],
                        self.argument_max_span_len[args_name])
                    args_s[s_r_i][span[0] + 1] = 1
                    args_e[e_r_i][span[1]] = 1  # +1 for [CLS], -1 to make the end inclusive

            if type_id != -1:
                datable("data_ids", id)
                datable("tokens_id", tokens_id)
                datable("token_masks", token_masks)
                datable("head_indexes", head_indexes)
                datable("type_id", type_id)
                datable("type_vec", type_vec)
                datable("r_pos", r_pos)
                datable("t_m", t_m)
                datable("t_index", t_index)
                datable("t_s", t_s)
                datable("t_e", t_e)
                datable("a_s", args_s)
                datable("a_e", args_e)
                datable("a_m", arg_mask)
                datable("content", content)

        return datable
Example #27
    def process(self, dataset, path=None):
        datable = DataTable()
        for item in dataset:
            words = item['words']
            triggers = ['O'] * len(words)
            arguments = {
                'candidates': [
                    # ex. (5, 6, "entity_type_str"), ...
                ],
                'events': {
                    # ex. (1, 3, "trigger_type_str"): [(5, 6, "argument_role_idx"), ...]
                },
            }

            for entity_mention in item['golden-entity-mentions']:
                arguments['candidates'].append(
                    (entity_mention['start'], entity_mention['end'],
                     entity_mention['entity-type']))

            for event_mention in item['golden-event-mentions']:
                for i in range(event_mention['trigger']['start'],
                               event_mention['trigger']['end']):
                    trigger_type = event_mention['event_type']
                    if i == event_mention['trigger']['start']:
                        triggers[i] = 'B-{}'.format(trigger_type)
                    else:
                        triggers[i] = 'I-{}'.format(trigger_type)

                event_key = (event_mention['trigger']['start'],
                             event_mention['trigger']['end'],
                             event_mention['event_type'])
                arguments['events'][event_key] = []
                for argument in event_mention['arguments']:
                    role = argument['role']
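                    # collapse role variants like 'Time-Within' into the base role 'Time'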
                    if role.startswith('Time'):
                        role = role.split('-')[0]
                    arguments['events'][event_key].append(
                        (argument['start'], argument['end'],
                         self.argument_vocabulary.to_index(role)))

            words = ['[CLS]'] + words + ['[SEP]']
            tokens_x, triggers_y, arguments, head_indexes, words, triggers = \
                process(words, triggers, arguments, self.tokenizer, self.trigger_vocabulary, self.max_length)

            datable('tokens_x', tokens_x)
            datable('triggers_y', triggers_y)
            datable('arguments', arguments)
            datable('head_indexes', head_indexes)
            datable('words', words)
            datable('triggers', triggers)

        if path and os.path.exists(path):
            datable.save_table(path)
        return datable