class Processor:
    """Base processor that owns a BERT tokenizer and a single label vocabulary.

    When ``label_list`` is non-empty, a vocabulary is built from it and written
    to ``<path>/vocabulary.txt``. Otherwise the vocabulary is read back from
    that same file.
    """

    def __init__(self, label_list=None, path=None, padding='<pad>', unknown='<unk>',
                 bert_model='bert-base-cased', max_length=256):
        self.path = path
        self.max_length = max_length
        self.bert_model = bert_model
        self.tokenizer = BertTokenizer.from_pretrained(self.bert_model)

        if not label_list:
            # No labels supplied: reuse the vocabulary persisted under `path`.
            self.load_vocabulary(self.path)
            return

        self.vocabulary = Vocabulary(padding=padding, unknown=unknown)
        self.vocabulary.add_word_lst(label_list)
        self.vocabulary.build_vocab()
        self.save_vocabulary(self.path)

    def set_vocabulary(self, vocabulary):
        """Replace the label vocabulary with ``vocabulary``."""
        self.vocabulary = vocabulary

    def get_vocabulary(self):
        """Return the current label vocabulary."""
        return self.vocabulary

    def save_vocabulary(self, path):
        """Write the vocabulary to ``<path>/vocabulary.txt``."""
        target_file = os.path.join(path, 'vocabulary.txt')
        self.vocabulary.save(target_file)

    def load_vocabulary(self, path):
        """Read the vocabulary from ``<path>/vocabulary.txt``."""
        source_file = os.path.join(path, 'vocabulary.txt')
        self.vocabulary = Vocabulary.load(source_file)

    def load(self):
        """Hook for subclasses; the base implementation does nothing."""
        return None
class FrameArgumentProcessor(Processor):
    """Processor for frame/argument labelling with separate trigger and argument vocabularies.

    The base-class vocabulary is handled by ``Processor``. The trigger and
    argument vocabularies are always built from ``trigger_label_list`` and
    ``argument_label_list``; they are neither saved nor loaded here.
    """

    def __init__(self, label_list=None, path=None, padding=None, unknown=None,
                 bert_model='bert-base-cased', max_length=256,
                 trigger_label_list=None, argument_label_list=None):
        super().__init__(label_list, path, padding=padding, unknown=unknown,
                         bert_model=bert_model, max_length=max_length)

        # The trigger vocabulary is given only a padding token; `unknown` is left
        # to Vocabulary's own default.
        self.trigger_vocabulary = Vocabulary(padding=padding)
        self.trigger_vocabulary.add_word_lst(trigger_label_list)
        self.trigger_vocabulary.build_vocab()

        self.argument_vocabulary = Vocabulary(padding=padding, unknown=unknown)
        self.argument_vocabulary.add_word_lst(argument_label_list)
        self.argument_vocabulary.build_vocab()

    def process(self, dataset):
        """Encode every ``(sentence, label, frame, pos)`` item of ``dataset`` into a DataTable."""
        table = DataTable()
        for item_no in range(len(dataset)):
            sentence, label, frame, pos = dataset[item_no]
            # NOTE: `process` here is the module-level encoding function, not
            # this method (method names are not in scope inside the body).
            (input_ids, attention_mask, segment_ids,
             head_indexes, label_ids, label_masks) = process(
                sentence, label, frame, pos, self.tokenizer,
                self.trigger_vocabulary, self.argument_vocabulary, self.max_length)
            for field, value in (('input_ids', input_ids),
                                 ('attention_mask', attention_mask),
                                 ('segment_ids', segment_ids),
                                 ('head_indexes', head_indexes),
                                 ('label_ids', label_ids),
                                 ('label_masks', label_masks)):
                table(field, value)
            table('frame', self.trigger_vocabulary.to_index(frame))
            table('pos', pos)
        return table
class ACE2005CASEEProcessor:
    """Build CasEE-style features for ACE2005 event extraction.

    The JSON file at ``schema_path`` maps each trigger (event) type to the
    argument roles it may take. The trigger and argument vocabularies are
    loaded from ``trigger_path`` / ``argument_path`` when those files exist.
    Otherwise they are built from the schema and saved there.

    Text is encoded word by word as ``[CLS] w1 ... wn [SEP]``, so word ``k`` of
    ``content`` sits at word position ``k + 1``. Dataset spans are shifted by
    that +1 and converted to closed intervals ``(start + 1, end)``.
    """

    def __init__(self, schema_path=None, trigger_path=None, argument_path=None,
                 bert_model='bert-base-cased', max_length=128):
        self.schema_path = schema_path
        self.trigger_path = trigger_path
        self.argument_path = argument_path
        self.bert_model = bert_model
        self.max_length = max_length
        self.tokenizer = BertTokenizer.from_pretrained(self.bert_model)

        with open(self.schema_path, 'r', encoding='utf-8') as f:
            self.schema_str = json.load(f)

        # Deduplicate while keeping schema order. The former set-based version
        # gave a different order in every process (string hash randomisation),
        # so a freshly built vocabulary was not reproducible between runs.
        self.trigger_type_list = list(dict.fromkeys(self.schema_str))
        self.argument_type_list = list(dict.fromkeys(
            role for roles in self.schema_str.values() for role in roles))
        self.args_s_id = {role + '_s': i for i, role in enumerate(self.argument_type_list)}
        self.args_e_id = {role + '_e': i for i, role in enumerate(self.argument_type_list)}

        self.trigger_vocabulary = self._load_or_build_vocabulary(
            self.trigger_path, self.trigger_type_list)
        self.argument_vocabulary = self._load_or_build_vocabulary(
            self.argument_path, self.argument_type_list)

        # Trigger-type index -> indices of the argument roles the schema allows for it.
        self.schema_id = {
            self.trigger_vocabulary.word2idx[trigger_type]:
                [self.argument_vocabulary.word2idx[a] for a in roles]
            for trigger_type, roles in self.schema_str.items()
        }
        self.trigger_type_num = len(self.trigger_vocabulary)
        self.argument_type_num = len(self.argument_vocabulary)

        # Running maximum of `span[1] - span[0]` per label, starting at 1; the
        # process_* methods update it as they go.
        self.trigger_max_span_len = {name: 1 for name in self.trigger_vocabulary.word2idx}
        self.argument_max_span_len = {name: 1 for name in self.argument_vocabulary.word2idx}

    @staticmethod
    def _load_or_build_vocabulary(path, words):
        """Load a vocabulary from ``path`` if it exists, else build it from ``words`` and save it there."""
        if os.path.exists(path):
            return Vocabulary.load(path)
        vocabulary = Vocabulary(padding=None, unknown=None)
        vocabulary.add_word_lst(words)
        vocabulary.build_vocab()
        vocabulary.save(path)
        return vocabulary

    def get_trigger_max_span_len(self):
        return self.trigger_max_span_len

    def get_argument_max_span_len(self):
        return self.argument_max_span_len

    def get_trigger_vocabulary(self):
        return self.trigger_vocabulary

    def get_argument_vocabulary(self):
        return self.argument_vocabulary

    # ------------------------------------------------------------------ #
    # Shared feature helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _iter_examples(dataset):
        """Iterate ``(content, index, type, args, occur, triggers, id)`` rows with a progress bar."""
        columns = ("content", "index", "type", "args", "occur", "triggers", "id")
        return tqdm(zip(*(dataset[column] for column in columns)),
                    total=len(dataset["content"]))

    def _encode_words(self, content):
        """Encode ``[CLS] + content + [SEP]`` into fixed-length model inputs.

        Returns ``(tokens_id, token_masks, head_indexes)``, each padded (with
        0 / False) or truncated to ``max_length``. ``head_indexes`` holds the
        position of the first word piece of every word, ``[CLS]`` and
        ``[SEP]`` included.
        """
        tokens_id, is_heads = [], []
        for word in ['[CLS]', *content, '[SEP]']:
            if word in ('[CLS]', '[SEP]'):
                pieces = [word]
            else:
                # A word the tokenizer maps to nothing (e.g. only whitespace or
                # control characters) used to add a head flag without a token
                # id, shifting every later head index. Fall back to the unknown
                # token so each word owns exactly one head position.
                pieces = self.tokenizer.tokenize(word) or [self.tokenizer.unk_token]
            tokens_id.extend(self.tokenizer.convert_tokens_to_ids(pieces))
            is_heads.extend([1] + [0] * (len(pieces) - 1))

        n_real = len(tokens_id)
        token_masks = ([True] * n_real + [False] * (self.max_length - n_real))[:self.max_length]
        tokens_id = (tokens_id + [0] * (self.max_length - n_real))[:self.max_length]
        head_indexes = [pos for pos, is_head in enumerate(is_heads[:self.max_length]) if is_head]
        head_indexes = (head_indexes + [0] * (self.max_length - len(head_indexes)))[:self.max_length]
        return tokens_id, token_masks, head_indexes

    def _type_features(self, event_type, occur):
        """Return ``(type_id, type_vec)``.

        ``type_id`` is the vocabulary index of ``event_type``, or -1 for
        ``"<unk>"``. ``type_vec`` is a multi-hot vector over every type in
        ``occur``.
        """
        type_vec = np.array([0] * self.trigger_type_num)
        type_id = -1
        if event_type != "<unk>":
            type_id = self.trigger_vocabulary.word2idx[event_type]
        for occ in occur:
            type_vec[self.trigger_vocabulary.word2idx[occ]] = 1
        return type_id, type_vec

    def _relative_positions(self, start_idx, end_idx):
        """Signed distance of each position to the closed span ``[start_idx, end_idx]``, shifted by ``max_length``."""
        rel = (list(range(-start_idx, 0))
               + [0] * (end_idx - start_idx + 1)
               + list(range(1, self.max_length - end_idx)))
        return [p + self.max_length for p in rel]

    def _trigger_features(self, event_type, index, triggers, track_span_len):
        """Return ``(r_pos, t_m)`` for the trigger ``triggers[index]``.

        ``t_m`` marks the trigger's first and last word positions with 1, and
        ``r_pos`` holds positions relative to the trigger. With ``index=None``
        the reference span is position 0 and ``t_m`` is all zeros. When
        ``track_span_len`` is true, ``trigger_max_span_len`` is updated.
        """
        t_m = [0] * self.max_length
        if index is None:
            return self._relative_positions(0, 0), t_m
        span = triggers[index]
        if track_span_len:
            self.trigger_max_span_len[event_type] = max(
                self.trigger_max_span_len[event_type], span[1] - span[0])
        start_idx, end_idx = span[0] + 1, span[1]  # +1 for [CLS]; end made inclusive
        t_m[start_idx] = 1
        t_m[end_idx] = 1
        return self._relative_positions(start_idx, end_idx), t_m

    def _train_targets(self, triggers, args):
        """Return ``(t_s, t_e, args_s, args_e, arg_mask)`` training labels.

        Also updates ``argument_max_span_len``.
        """
        t_s = [0] * self.max_length
        t_e = [0] * self.max_length
        for trigger in triggers:
            t_s[trigger[0] + 1] = 1
            t_e[trigger[1]] = 1
        args_s = np.zeros(shape=[self.argument_type_num, self.max_length])
        args_e = np.zeros(shape=[self.argument_type_num, self.max_length])
        arg_mask = [0] * self.argument_type_num
        for role in args:
            role_idx = self.argument_vocabulary.word2idx[role]
            arg_mask[role_idx] = 1
            for span in args[role]:
                self.argument_max_span_len[role] = max(
                    span[1] - span[0], self.argument_max_span_len[role])
                args_s[role_idx][span[0] + 1] = 1
                args_e[role_idx][span[1]] = 1
        return t_s, t_e, args_s, args_e, arg_mask

    def _eval_targets(self, triggers, args, track_span_len):
        """Return ``(triggers_truth, args_truth)`` as closed, [CLS]-shifted spans.

        ``args_truth`` maps each role index to its spans. When
        ``track_span_len`` is true, ``argument_max_span_len`` is updated.
        """
        triggers_truth = [(span[0] + 1, span[1]) for span in triggers]
        args_truth = {i: [] for i in range(self.argument_type_num)}
        for role in args:
            role_idx = self.argument_vocabulary.word2idx[role]
            for span in args[role]:
                if track_span_len:
                    self.argument_max_span_len[role] = max(
                        span[1] - span[0], self.argument_max_span_len[role])
                args_truth[role_idx].append((span[0] + 1, span[1]))
        return triggers_truth, args_truth

    def _process_eval(self, dataset, track_trigger_len, track_argument_len):
        """Shared body of ``process_dev`` / ``process_test``.

        Examples whose type is ``"<unk>"`` still update the span statistics
        but are not emitted.
        """
        datable = DataTable()
        for content, index, event_type, args, occur, triggers, data_id in self._iter_examples(dataset):
            tokens_id, token_masks, head_indexes = self._encode_words(content)
            type_id, type_vec = self._type_features(event_type, occur)
            r_pos, t_m = self._trigger_features(event_type, index, triggers, track_trigger_len)
            triggers_truth, args_truth = self._eval_targets(triggers, args, track_argument_len)
            if type_id == -1:
                continue
            datable("data_ids", data_id)
            datable("type_id", type_id)
            datable("type_vec", type_vec)
            datable("tokens_id", tokens_id)
            datable("token_masks", token_masks)
            datable("t_index", index)
            datable("r_pos", r_pos)
            datable("t_m", t_m)
            datable("triggers_truth", triggers_truth)
            datable("args_truth", args_truth)
            datable("head_indexes", head_indexes)
            datable("content", content)
        return datable

    # ------------------------------------------------------------------ #
    # Public entry points
    # ------------------------------------------------------------------ #

    def process_train(self, dataset):
        """Build training features; updates both trigger and argument span-length stats.

        Examples whose type is ``"<unk>"`` update the statistics but are not emitted.
        """
        datable = DataTable()
        for content, index, event_type, args, occur, triggers, data_id in self._iter_examples(dataset):
            tokens_id, token_masks, head_indexes = self._encode_words(content)
            type_id, type_vec = self._type_features(event_type, occur)
            r_pos, t_m = self._trigger_features(event_type, index, triggers, True)
            t_s, t_e, args_s, args_e, arg_mask = self._train_targets(triggers, args)
            if type_id == -1:
                continue
            datable("data_ids", data_id)
            datable("tokens_id", tokens_id)
            datable("token_masks", token_masks)
            datable("head_indexes", head_indexes)
            datable("type_id", type_id)
            datable("type_vec", type_vec)
            datable("r_pos", r_pos)
            datable("t_m", t_m)
            datable("t_index", index)
            datable("t_s", t_s)
            datable("t_e", t_e)
            datable("a_s", args_s)
            datable("a_e", args_e)
            datable("a_m", arg_mask)
            datable("content", content)
        return datable

    def process_dev(self, dataset):
        """Build evaluation features; updates both trigger and argument span-length stats."""
        return self._process_eval(dataset, track_trigger_len=True, track_argument_len=True)

    def process_test(self, dataset):
        """Build evaluation features; updates only the trigger span-length stats."""
        return self._process_eval(dataset, track_trigger_len=True, track_argument_len=False)
class FrameNet4JointProcessor:
    """Turn FrameNet-style annotations into span, node and edge features for joint parsing.

    Four label vocabularies are used: node types, node attributes, ``p2p``
    edges and ``p2r`` edges. Each is built from its label list and saved
    under ``path``, or loaded from there when that list is not given.
    """

    def __init__(self, node_types_label_list=None, node_attrs_label_list=None,
                 p2p_edges_label_list=None, p2r_edges_label_list=None, path=None,
                 bert_model='bert-base-cased', max_span_width=15, max_length=128):
        self.path = path
        self.bert_model = bert_model
        self.max_length = max_length
        self.tokenizer = BertTokenizer.from_pretrained(bert_model)
        self.max_span_width = max_span_width
        self._ontology = FrameOntology(self.path)

        self.node_types_vocabulary = self._load_or_build_vocabulary(
            node_types_label_list, 'node_types_vocabulary.txt', padding="O")
        self.node_attrs_vocabulary = self._load_or_build_vocabulary(
            node_attrs_label_list, 'node_attrs_vocabulary.txt', padding="O")
        self.p2p_edges_vocabulary = self._load_or_build_vocabulary(
            p2p_edges_label_list, 'p2p_edges_vocabulary.txt', padding=None)
        self.p2r_edges_vocabulary = self._load_or_build_vocabulary(
            p2r_edges_label_list, 'p2r_edges_vocabulary.txt', padding=None)

    def _load_or_build_vocabulary(self, label_list, filename, padding):
        """Build (and save) a vocabulary from ``label_list``, or load ``<path>/<filename>`` if it is empty."""
        file_path = os.path.join(self.path, filename)
        if not label_list:
            return Vocabulary.load(file_path)
        vocabulary = Vocabulary(padding=padding, unknown=None)
        vocabulary.add_word_lst(label_list)
        vocabulary.build_vocab()
        vocabulary.save(file_path)
        return vocabulary

    def get_node_types_vocabulary(self):
        return self.node_types_vocabulary

    def get_node_attrs_vocabulary(self):
        return self.node_attrs_vocabulary

    def get_p2p_edges_vocabulary(self):
        return self.p2p_edges_vocabulary

    def get_p2r_edges_vocabulary(self):
        return self.p2r_edges_vocabulary

    def process(self, dataset):
        """Run ``process_item`` on every example and collect the features into a DataTable."""
        columns = ("words", "lemma", "node_types", "node_attrs", "origin_lexical_units",
                   "p2p_edges", "p2r_edges", "origin_frames", "frame_elements")
        # Field names in the same order as process_item's return tuple.
        fields = ("tokens_x",
                  "token_masks",
                  "head_indexes",
                  "spans",
                  "node_type_labels_list",  # coarse-grained node classification
                  "node_attr_labels_list",  # fine-grained node classification
                  "node_valid_attrs_list",
                  "valid_p2r_edges_list",
                  "p2p_edge_labels_and_indices",
                  "p2r_edge_labels_and_indices",
                  "raw_words_len",
                  "n_spans")
        datable = DataTable()
        for example in tqdm(zip(*(dataset[column] for column in columns)),
                            total=len(dataset['words'])):
            for field, value in zip(fields, self.process_item(*example)):
                datable(field, value)
        return datable

    def _encode_words(self, raw_words):
        """Encode ``[CLS] + raw_words + [SEP]`` into ids, mask and word-head positions.

        Only real words are heads; ``[CLS]`` / ``[SEP]`` are not. Lists are
        padded to ``max_length`` but NOT truncated, so long inputs yield longer
        lists.
        """
        tokens_x, is_heads = [], []
        for word in ['[CLS]', *raw_words, '[SEP]']:
            if word in ('[CLS]', '[SEP]'):
                pieces, is_head = [word], [0]
            else:
                # A word the tokenizer maps to nothing (e.g. only whitespace or
                # control characters) used to add a head flag without a token
                # id, shifting every later head index. Fall back to the unknown
                # token so each word owns exactly one head position.
                pieces = self.tokenizer.tokenize(word) or [self.tokenizer.unk_token]
                is_head = [1] + [0] * (len(pieces) - 1)
            tokens_x.extend(self.tokenizer.convert_tokens_to_ids(pieces))
            is_heads.extend(is_head)

        token_masks = [True] * len(tokens_x) + [False] * (self.max_length - len(tokens_x))
        tokens_x = tokens_x + [0] * (self.max_length - len(tokens_x))
        head_indexes = [pos for pos, flag in enumerate(is_heads) if flag]
        head_indexes = head_indexes + [0] * (self.max_length - len(head_indexes))
        return tokens_x, token_masks, head_indexes

    def process_item(self, raw_words, lemmas, node_types, node_attrs, origin_lexical_units,
                     p2p_edges, p2r_edges, origin_frames, frame_elements):
        """Build the features of one sentence.

        Returns a 12-tuple:
        ``(tokens_x, token_masks, head_indexes, spans, node_type_labels_list,
        node_attr_labels_list, node_valid_attrs_list, valid_p2r_edges_list,
        p2p_edge_labels_and_indices, p2r_edge_labels_and_indices,
        raw_words_len, n_spans)``.

        ``spans`` are inclusive ``(start, end)`` word spans from ``get_spans``.
        The two edge dicts have ``"indices"`` (pairs of span indices) and
        ``"labels"`` keys. ``lemmas`` is accepted but not used.
        """
        raw_words_len = len(raw_words)
        tokens_x, token_masks, head_indexes = self._encode_words(raw_words)

        (node_types_dict, node_attrs_dict, origin_lus_dict,
         p2p_edges_dict, p2r_edges_dict, origin_frames_dict, frame_elements_dict) = \
            format_label_fields(node_types, node_attrs, origin_lexical_units,
                                p2p_edges, p2r_edges, origin_frames, frame_elements)

        # Per-span node labels. NOTE(review): these lookups rely on the
        # format_label_fields dicts returning a default for unlabelled spans.
        node_type_labels_list, node_attr_labels_list = [], []
        node_valid_attrs_list, valid_p2r_edges_list = [], []
        spans = self.get_spans(raw_words, max_span_width=self.max_span_width)
        for span_ix in spans:
            node_type_labels_list.append(node_types_dict[span_ix])
            node_attr_label = node_attrs_dict[span_ix]
            node_attr_labels_list.append(node_attr_label)

            lexical_unit = origin_lus_dict[span_ix]
            if lexical_unit in self._ontology.lu_frame_map:
                valid_attrs = self._ontology.lu_frame_map[lexical_unit]
            else:
                valid_attrs = ["O"]
            node_valid_attrs_list.append(list(valid_attrs))

            if node_attr_label in self._ontology.frame_fe_map:
                valid_p2r_edges_list.append(list(self._ontology.frame_fe_map[node_attr_label]))
            else:
                valid_p2r_edges_list.append([-1])

        # Every ordered pair of spans (i, j) is a candidate edge; only pairs
        # with a truthy label are kept. Iterating directly avoids materialising
        # an n_spans**2 list of candidate index pairs.
        n_spans = len(spans)
        p2p_edge_indices, p2p_edge_labels = [], []
        p2r_edge_indices, p2r_edge_labels = [], []
        for i, span_i in enumerate(spans):
            for j, span_j in enumerate(spans):
                span_pair = (span_i, span_j)
                p2p_edge_label = p2p_edges_dict[span_pair]
                p2r_edge_label = p2r_edges_dict[span_pair]
                if p2p_edge_label:
                    p2p_edge_indices.append((i, j))
                    p2p_edge_labels.append(p2p_edge_label)
                if p2r_edge_label:
                    p2r_edge_indices.append((i, j))
                    p2r_edge_labels.append(p2r_edge_label)

        p2p_edge_labels_and_indices = {"indices": p2p_edge_indices, "labels": p2p_edge_labels}
        p2r_edge_labels_and_indices = {"indices": p2r_edge_indices, "labels": p2r_edge_labels}

        return (tokens_x, token_masks, head_indexes, spans,
                node_type_labels_list, node_attr_labels_list,
                node_valid_attrs_list, valid_p2r_edges_list,
                p2p_edge_labels_and_indices, p2r_edge_labels_and_indices,
                raw_words_len, n_spans)

    def get_spans(self, tokens, min_span_width=1, max_span_width=None, filter_function=None):
        """Enumerate inclusive ``(start, end)`` spans over ``tokens``.

        Spans hold between ``min_span_width`` and ``max_span_width`` tokens. A
        falsy ``max_span_width`` means ``len(tokens)``. A span is kept only if
        ``filter_function(tokens[start:end + 1])`` is truthy; by default every
        span is kept. Spans are ordered by start, then by end.
        """
        max_span_width = max_span_width or len(tokens)
        filter_function = filter_function or (lambda x: True)
        spans = []
        for start in range(len(tokens)):
            first_end = min(start + min_span_width - 1, len(tokens))
            last_end = min(start + max_span_width, len(tokens))
            for end in range(first_end, last_end):
                if filter_function(tokens[start:end + 1]):
                    spans.append((start, end))
        return spans
class FINANCECASEEProcessor:
    """Build CasEE-style features for the Chinese financial event dataset.

    The JSON file at ``schema_path`` maps each trigger (event) type to its
    argument roles. Unlike the ACE processor, the trigger and argument
    vocabularies are always rebuilt from the fixed ``TRIGGER_LABELS`` /
    ``ARGUMENT_LABELS`` inventories. They are then written to
    ``trigger_path`` / ``argument_path``, overwriting any existing file.

    Content is fed to the tokenizer one character at a time (lower-cased),
    with special tokens added. Dataset spans are shifted by +1 and converted
    to closed intervals ``(start + 1, end)``.
    NOTE(review): this assumes each character maps to exactly one input id
    after the leading [CLS]; confirm for non-BMP or multi-char items.
    """

    # Fixed label inventories used to build the vocabularies.
    TRIGGER_LABELS = ('质押', '股份股权转让', '投资', '减持', '起诉', '收购',
                      '判决', '签署合同', '担保', '中标')
    ARGUMENT_LABELS = ('collateral', 'obj-per', 'sub-per', 'sub-org', 'share-per',
                       'title', 'way', 'money', 'obj-org', 'number', 'amount',
                       'proportion', 'target-company', 'date', 'sub', 'share-org',
                       'obj', 'institution')

    def __init__(self, schema_path=None, trigger_path=None, argument_path=None,
                 bert_model='bert-base-chinese', max_length=128):
        self.schema_path = schema_path
        self.trigger_path = trigger_path
        self.argument_path = argument_path
        self.bert_model = bert_model
        self.max_length = max_length
        self.tokenizer = BertTokenizer.from_pretrained(self.bert_model)

        with open(self.schema_path, 'r', encoding='utf-8') as f:
            self.schema_str = json.load(f)

        # Deduplicate while keeping schema order (the former set-based lists
        # changed order between processes because of string hash randomisation).
        self.trigger_type_list = list(dict.fromkeys(self.schema_str))
        self.argument_type_list = list(dict.fromkeys(
            role for roles in self.schema_str.values() for role in roles))
        self.args_s_id = {role + '_s': i for i, role in enumerate(self.argument_type_list)}
        self.args_e_id = {role + '_e': i for i, role in enumerate(self.argument_type_list)}

        self.trigger_vocabulary = self._build_and_save_vocabulary(
            self.TRIGGER_LABELS, self.trigger_path)
        self.argument_vocabulary = self._build_and_save_vocabulary(
            self.ARGUMENT_LABELS, self.argument_path)

        # Trigger-type index -> indices of the argument roles the schema allows for it.
        self.schema_id = {
            self.trigger_vocabulary.word2idx[trigger_type]:
                [self.argument_vocabulary.word2idx[a] for a in roles]
            for trigger_type, roles in self.schema_str.items()
        }
        self.trigger_type_num = len(self.trigger_vocabulary)
        self.argument_type_num = len(self.argument_vocabulary)

        # Running maximum of `span[1] - span[0]` per label, starting at 1; the
        # process_* methods update it as they go.
        self.trigger_max_span_len = {name: 1 for name in self.trigger_vocabulary.word2idx}
        self.argument_max_span_len = {name: 1 for name in self.argument_vocabulary.word2idx}

    @staticmethod
    def _build_and_save_vocabulary(words, path):
        """Build a vocabulary (no padding / unknown token) from ``words`` and save it to ``path``."""
        vocabulary = Vocabulary(padding=None, unknown=None)
        vocabulary.add_word_lst(list(words))
        vocabulary.build_vocab()
        vocabulary.save(path)
        return vocabulary

    def get_trigger_max_span_len(self):
        return self.trigger_max_span_len

    def get_argument_max_span_len(self):
        return self.argument_max_span_len

    def get_trigger_vocabulary(self):
        return self.trigger_vocabulary

    def get_argument_vocabulary(self):
        return self.argument_vocabulary

    # ------------------------------------------------------------------ #
    # Shared feature helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _iter_examples(dataset):
        """Iterate ``(content, index, type, args, occur, triggers, id)`` rows with a progress bar."""
        columns = ("content", "index", "type", "args", "occur", "triggers", "id")
        return tqdm(zip(*(dataset[column] for column in columns)),
                    total=len(dataset["content"]))

    def _encode_chars(self, content):
        """Encode ``content`` character by character into fixed-length model inputs.

        Returns ``(tokens_id, token_masks, head_indexes)``. Every non-padding
        position is its own head, so ``head_indexes`` is ``0..n-1`` followed
        by zero padding to ``max_length``.
        """
        # Iterating a string yields one character at a time; each lower-cased
        # character becomes one item of the tokenizer input list.
        chars = [ch.lower() for ch in content]
        inputs = self.tokenizer.encode_plus(chars, add_special_tokens=True,
                                            max_length=self.max_length,
                                            truncation=True, padding='max_length')
        tokens_id, token_masks = inputs["input_ids"], inputs['attention_mask']
        head_indexes = list(np.arange(0, sum(token_masks)))
        head_indexes = (head_indexes + [0] * (self.max_length - len(head_indexes)))[:self.max_length]
        return tokens_id, token_masks, head_indexes

    def _type_features(self, event_type, occur):
        """Return ``(type_id, type_vec)``.

        ``type_id`` is -1 for ``"<unk>"``. ``type_vec`` is multi-hot over the
        types in ``occur``.
        """
        type_vec = np.array([0] * self.trigger_type_num)
        type_id = -1
        if event_type != "<unk>":
            type_id = self.trigger_vocabulary.word2idx[event_type]
        for occ in occur:
            type_vec[self.trigger_vocabulary.word2idx[occ]] = 1
        return type_id, type_vec

    def _relative_positions(self, start_idx, end_idx):
        """Signed distance of each position to the closed span ``[start_idx, end_idx]``, shifted by ``max_length``."""
        rel = (list(range(-start_idx, 0))
               + [0] * (end_idx - start_idx + 1)
               + list(range(1, self.max_length - end_idx)))
        return [p + self.max_length for p in rel]

    def _trigger_features(self, event_type, index, triggers, track_span_len):
        """Return ``(r_pos, t_m)`` for the trigger ``triggers[index]``.

        With ``index=None`` the reference span is position 0 and ``t_m`` is
        all zeros. When ``track_span_len`` is true, ``trigger_max_span_len``
        is updated.
        """
        t_m = [0] * self.max_length
        if index is None:
            return self._relative_positions(0, 0), t_m
        span = triggers[index]
        if track_span_len:
            self.trigger_max_span_len[event_type] = max(
                self.trigger_max_span_len[event_type], span[1] - span[0])
        start_idx, end_idx = span[0] + 1, span[1]  # +1 for [CLS]; end made inclusive
        t_m[start_idx] = 1
        t_m[end_idx] = 1
        return self._relative_positions(start_idx, end_idx), t_m

    def _train_targets(self, triggers, args):
        """Return ``(t_s, t_e, args_s, args_e, arg_mask)`` training labels; updates ``argument_max_span_len``."""
        t_s = [0] * self.max_length
        t_e = [0] * self.max_length
        for trigger in triggers:
            t_s[trigger[0] + 1] = 1
            t_e[trigger[1]] = 1
        args_s = np.zeros(shape=[self.argument_type_num, self.max_length])
        args_e = np.zeros(shape=[self.argument_type_num, self.max_length])
        arg_mask = [0] * self.argument_type_num
        for role in args:
            role_idx = self.argument_vocabulary.word2idx[role]
            arg_mask[role_idx] = 1
            for span in args[role]:
                self.argument_max_span_len[role] = max(
                    span[1] - span[0], self.argument_max_span_len[role])
                args_s[role_idx][span[0] + 1] = 1
                args_e[role_idx][span[1]] = 1
        return t_s, t_e, args_s, args_e, arg_mask

    def _eval_targets(self, triggers, args, track_span_len):
        """Return ``(triggers_truth, args_truth)`` as closed, [CLS]-shifted spans (role index -> spans)."""
        triggers_truth = [(span[0] + 1, span[1]) for span in triggers]
        args_truth = {i: [] for i in range(self.argument_type_num)}
        for role in args:
            role_idx = self.argument_vocabulary.word2idx[role]
            for span in args[role]:
                if track_span_len:
                    self.argument_max_span_len[role] = max(
                        span[1] - span[0], self.argument_max_span_len[role])
                args_truth[role_idx].append((span[0] + 1, span[1]))
        return triggers_truth, args_truth

    def _process_eval(self, dataset, track_trigger_len, track_argument_len):
        """Shared body of ``process_dev`` / ``process_test``; ``"<unk>"``-typed examples are not emitted."""
        datable = DataTable()
        for content, index, event_type, args, occur, triggers, data_id in self._iter_examples(dataset):
            tokens_id, token_masks, head_indexes = self._encode_chars(content)
            type_id, type_vec = self._type_features(event_type, occur)
            r_pos, t_m = self._trigger_features(event_type, index, triggers, track_trigger_len)
            triggers_truth, args_truth = self._eval_targets(triggers, args, track_argument_len)
            if type_id == -1:
                continue
            datable("data_ids", data_id)
            datable("type_id", type_id)
            datable("type_vec", type_vec)
            datable("tokens_id", tokens_id)
            datable("token_masks", token_masks)
            datable("t_index", index)
            datable("r_pos", r_pos)
            datable("t_m", t_m)
            datable("triggers_truth", triggers_truth)
            datable("args_truth", args_truth)
            datable("head_indexes", head_indexes)
            datable("content", content)
        return datable

    # ------------------------------------------------------------------ #
    # Public entry points
    # ------------------------------------------------------------------ #

    def process_train(self, dataset):
        """Build training features; updates both trigger and argument span-length stats."""
        datable = DataTable()
        for content, index, event_type, args, occur, triggers, data_id in self._iter_examples(dataset):
            tokens_id, token_masks, head_indexes = self._encode_chars(content)
            type_id, type_vec = self._type_features(event_type, occur)
            r_pos, t_m = self._trigger_features(event_type, index, triggers, True)
            t_s, t_e, args_s, args_e, arg_mask = self._train_targets(triggers, args)
            if type_id == -1:
                continue
            datable("data_ids", data_id)
            datable("tokens_id", tokens_id)
            datable("token_masks", token_masks)
            datable("head_indexes", head_indexes)
            datable("type_id", type_id)
            datable("type_vec", type_vec)
            datable("r_pos", r_pos)
            datable("t_m", t_m)
            datable("t_index", index)
            datable("t_s", t_s)
            datable("t_e", t_e)
            datable("a_s", args_s)
            datable("a_e", args_e)
            datable("a_m", arg_mask)
            datable("content", content)
        return datable

    def process_dev(self, dataset):
        """Build evaluation features; updates both trigger and argument span-length stats."""
        return self._process_eval(dataset, track_trigger_len=True, track_argument_len=True)

    def process_test(self, dataset):
        """Build evaluation features without touching the span-length stats."""
        return self._process_eval(dataset, track_trigger_len=False, track_argument_len=False)