import json
import logging
import random

from tqdm import tqdm

logger = logging.getLogger(__name__)

# Assumed module-level state (hypothetical defaults; in the real code base
# these are filled in from command-line args / config elsewhere):
ner_labels = []   # collected NER label names
seg_len = 0       # segment window length (0 = no segmentation)
seg_backoff = 0   # overlap between consecutive segments
fold = 0          # which cross-validation fold to hold out


def load_test_examples(args, test_file, seg_len=0, seg_backoff=0):
    test_examples = []
    for guid, text, _, _ in test_data_generator(args,
                                                test_file,
                                                seg_len=seg_len,
                                                seg_backoff=seg_backoff):
        # text = clean_text(text)
        # seg_generator yields (segments, offset); the offset is unused here.
        for (seg_text, ), _ in seg_generator((text, ), seg_len, seg_backoff):
            test_examples.append(
                InputExample(guid=guid, text_a=seg_text, labels=None))
    logger.info(f"Loaded {len(test_examples)} test examples.")
    return test_examples

def load_ner_examples(data_generator, examples_file, seg_len=0, seg_backoff=0):
    examples = []
    for guid, text_a, _, labels in data_generator(examples_file):
        assert text_a is not None
        for (seg_text_a, ), text_offset in seg_generator((text_a, ), seg_len,
                                                         seg_backoff):
            # labels are passed through unchanged for every segment; any
            # re-basing to segment-local offsets is the caller's concern.
            examples.append(
                InputExample(guid=guid,
                             text_a=seg_text_a,
                             labels=labels,
                             text_offset=text_offset))
    logger.info(f"Loaded {len(examples)} examples.")
    return examples

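
# A minimal sketch of seg_generator, which is not defined in this file
# (hypothetical; the real implementation lives elsewhere). Inferred from the
# call sites: slide a fixed-size window over each of the parallel sequences,
# stepping by seg_len - seg_backoff so consecutive segments overlap by
# seg_backoff characters, and yield the tuple of segment slices together with
# the window's start offset. seg_len == 0 is taken to mean "do not segment".
def seg_generator(sequences, seg_len=0, seg_backoff=0):
    if seg_len <= 0:
        yield tuple(sequences), 0
        return
    stride = seg_len - seg_backoff
    assert stride > 0, "seg_backoff must be smaller than seg_len"
    max_len = max(len(x) for x in sequences)
    offset = 0
    while offset < max_len:
        yield tuple(x[offset:offset + seg_len] for x in sequences), offset
        offset += stride
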
# A later variant of load_test_examples; it shadows the definition above when
# both live in one module, and reads seg_len/seg_backoff from module scope.
def load_test_examples(test_base_file):
    test_examples = []
    with open(test_base_file, 'r') as fr:
        lines = fr.readlines()
        for line in tqdm(lines, desc="test"):
            d = json.loads(line)
            guid = d['doc_id']
            words = list(d['content'])
            for (seg_words, ), _ in seg_generator((words, ), seg_len,
                                                  seg_backoff):
                test_examples.append(
                    InputExample(guid=guid, text_a=seg_words, labels=None))
    logger.info(f"Loaded {len(test_examples)} test examples.")
    return test_examples

def load_spo_examples(data_generator, examples_file, seg_len=0, seg_backoff=0):
    examples = []
    for guid, text, text_b, spo_list in data_generator(examples_file):
        assert text is not None
        for (seg_text, ), text_offset in seg_generator((text, ), seg_len,
                                                       seg_backoff):
            examples.append(
                InputExample(guid=guid,
                             text=seg_text,
                             spo_list=spo_list,
                             text_offset=text_offset))
    logger.info(f"Loaded {len(examples)} examples.")
    return examples

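
# A minimal sketch of InputExample (hypothetical; the real class is imported
# from elsewhere). The loaders above construct it with different keyword
# arguments (text_a/labels for NER, text/spo_list/text_offset for SPO), so a
# permissive attribute container is assumed.
class InputExample:
    def __init__(self, guid, **kwargs):
        self.guid = guid
        for key, value in kwargs.items():
            setattr(self, key, value)
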
def train_data_generator(args, train_base_file, seg_len=0, seg_backoff=0):
    with open(train_base_file, 'r') as fr:
        lines = fr.readlines()
    for line in tqdm(lines, desc="train & eval"):
        d = json.loads(line)
        guid = d['doc_id']
        text = d['content']
        for (seg_text, ), _ in seg_generator((text, ), seg_len, seg_backoff):
            seg_labels = []
            for e in d['events']:
                event_type = e['event_type']
                # if event_type not in ['破产清算']:  # ['股东减持', '股东增持']
                #     continue
                for k, v in e.items():
                    if not v:
                        continue
                    if k in ['event_id', 'event_type']:
                        continue
                    # Each (event_type, role) pair is an NER label.
                    label = '_'.join((event_type, k))
                    if label not in ner_labels:
                        ner_labels.append(label)
                    # Tag every occurrence of the role value in this segment.
                    i0 = seg_text.find(v)
                    while i0 >= 0:
                        if len(v) == 1:
                            # Single-character mentions are skipped
                            # (single-token S- tagging is disabled).
                            pass
                        else:
                            assert i0 < seg_len and i0 + len(v) - 1 < seg_len
                            seg_labels.append((label, i0, i0 + len(v) - 1))
                        i0 = seg_text.find(v, i0 + len(v))
            yield guid, seg_text, None, seg_labels

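
# Example input line for train_data_generator (schema inferred from the keys
# accessed above; the concrete values are made up for illustration):
#
#   {"doc_id": "doc_0001",
#    "content": "某某公司进入破产清算程序。",
#    "events": [{"event_id": "e1",
#                "event_type": "破产清算",
#                "公司名称": "某某公司"}]}
#
# Every non-empty event field other than event_id/event_type becomes an NER
# label of the form f"{event_type}_{field}" (here "破产清算_公司名称"), and
# every occurrence of the field value in a segment is tagged with it.
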
def data_seg_generator(lines,
                       ner_labels,
                       seg_len=0,
                       seg_backoff=0,
                       num_augments=0,
                       allow_overlap=False):
    assert 0 <= seg_backoff <= int(seg_len * 3 / 4)
    all_text_entities = []
    labels_map = {}

    num_overlap = 0
    for i, line in enumerate(tqdm(lines)):
        guid = str(i)
        text = line['text'].strip()
        entities = line['entities']

        # Drop entities whose span overlaps an already accepted span
        # (unless allow_overlap is set).
        new_entities = []
        used_span = []
        entities = sorted(entities, key=lambda e: e.start)
        for entity in entities:
            if entity.category not in ner_labels:
                continue
            entity.mention = text[entity.start:entity.end + 1]
            s = entity.start
            e = entity.end
            overlap = False
            for us in used_span:
                if us[0] <= s <= us[1] or us[0] <= e <= us[1]:
                    overlap = True
                    break
            if overlap:
                num_overlap += 1
                if not allow_overlap:
                    continue
            used_span.append((s, e))
            new_entities.append(entity)
        entities = new_entities

        seg_offset = 0
        for (seg_text, ), text_offset in seg_generator((text, ), seg_len,
                                                       seg_backoff):
            text_a = seg_text
            assert text_offset == seg_offset
            seg_start = seg_offset
            seg_end = seg_offset + min(seg_len, len(seg_text))
            # Keep only entities fully inside this segment, re-based to
            # segment-local offsets.
            labels = [(x.category, x.start - seg_offset, x.end - seg_offset)
                      for x in entities
                      if x.start >= seg_offset and x.end < seg_end]

            # Segments without any annotation are not used for training.
            if labels:
                yield guid, text_a, None, labels, seg_offset

                if num_augments > 0:
                    seg_entities = [{
                        'start': x.start - seg_offset,
                        'end': x.end - seg_offset,
                        'category': x.category,
                        'mention': x.mention
                    } for x in entities
                                    if x.start >= seg_offset and x.end < seg_end]
                    all_text_entities.append((guid, text_a, seg_entities))

                    # Collect mention texts per label for augmentation.
                    for entity in seg_entities:
                        label_type = entity['category']
                        s = entity['start']
                        e = entity['end']
                        assert e >= s
                        assert s >= 0 and e < len(seg_text)
                        entity_text = seg_text[s:e + 1]
                        assert len(entity_text) > 0
                        if label_type not in labels_map:
                            labels_map[label_type] = []
                        labels_map[label_type].append(entity_text)

            seg_offset += seg_len - seg_backoff
    logger.warning(f"num_overlap: {num_overlap}")

    if num_augments > 0:
        aug_tokens = augment_entities(all_text_entities,
                                      labels_map,
                                      num_augments=num_augments)
        for guid, text, entities in aug_tokens:
            text_a = text
            for entity in entities:
                assert text_a[entity['start']:entity['end'] +
                              1] == entity['mention']
            labels = [(entity['category'], entity['start'], entity['end'])
                      for entity in entities
                      if entity['end'] < (min(len(text_a), seg_len)
                                          if seg_len > 0 else len(text_a))]
            yield guid, text_a, None, labels, 0

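
# A minimal sketch of augment_entities, which is not defined in this file
# (hypothetical; inferred from the call sites). Assumed behavior: for each
# collected (guid, text, entities) triple, produce num_augments augmented
# copies by replacing every entity mention with a randomly sampled mention of
# the same category from labels_map, shifting subsequent offsets so that
# text[start:end + 1] == mention still holds.
def augment_entities(all_text_entities, labels_map, num_augments=1):
    aug_tokens = []
    for guid, text, entities in all_text_entities:
        for aug_idx in range(num_augments):
            new_text = text
            new_entities = []
            delta = 0  # cumulative length change from earlier replacements
            for entity in sorted(entities, key=lambda x: x['start']):
                candidates = labels_map.get(entity['category'], [])
                mention = (random.choice(candidates)
                           if candidates else entity['mention'])
                s = entity['start'] + delta
                new_text = (new_text[:s] + mention +
                            new_text[s + len(entity['mention']):])
                delta += len(mention) - len(entity['mention'])
                new_entities.append({
                    'start': s,
                    'end': s + len(mention) - 1,
                    'category': entity['category'],
                    'mention': mention
                })
            aug_tokens.append((f"{guid}_aug{aug_idx}", new_text, new_entities))
    return aug_tokens
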
def load_train_eval_examples(train_base_file):
    label2id = {}
    id2label = {}
    train_base_examples = []
    with open(train_base_file, 'r') as fr:
        lines = fr.readlines()
        for line in tqdm(lines, desc="train & eval"):
            d = json.loads(line)
            guid = d['doc_id']
            text = d['content']
            words = list(text)
            labels = ['O'] * len(words)
            for e in d['events']:
                event_type = e['event_type']
                # if event_type not in ['破产清算']:  # ['股东减持', '股东增持']
                #     continue
                for k, v in e.items():
                    if not v:
                        continue
                    if k in ['event_id', 'event_type']:
                        continue
                    label = '_'.join((event_type, k))
                    if label not in label2id:
                        n = len(label2id) + 1
                        label2id[label] = n
                        id2label[n] = label
                        ner_labels.append(label)
                    # BIO-tag every occurrence of the role value in the text.
                    i0 = text.find(v)
                    while i0 >= 0:
                        if len(v) == 1:
                            # Single-character mentions are skipped
                            # (single-token S- tagging is disabled).
                            pass
                        else:
                            labels[i0] = f"B-{label}"
                            for j0 in range(1, len(v)):
                                labels[i0 + j0] = f"I-{label}"
                        i0 = text.find(v, i0 + 1)

            for (seg_words, seg_labels), _ in seg_generator((words, labels),
                                                            seg_len,
                                                            seg_backoff):
                train_base_examples.append(
                    InputExample(guid=guid,
                                 text_a=seg_words,
                                 labels=seg_labels))

    # train_base_examples = train_base_examples[:100]
    # Shuffle, then split into train/eval by fold: fold 0 holds out the last
    # (1 - train_rate) of the examples, fold k >= 1 holds out the k-th
    # leading block of the same size.
    random.shuffle(train_base_examples)
    train_rate = 0.9
    num_eval_examples = int(len(train_base_examples) * (1 - train_rate))
    num_train_samples = len(train_base_examples) - num_eval_examples
    if fold == 0:
        eval_examples = train_base_examples[num_train_samples:]
        train_examples = train_base_examples[:num_train_samples]
    else:
        s = num_eval_examples * (fold - 1)
        e = num_eval_examples * fold
        eval_examples = train_base_examples[s:e]
        train_examples = train_base_examples[:s] + train_base_examples[e:]
    logger.info(f"Loaded {len(train_examples)} train examples, "
                f"{len(eval_examples)} eval examples.")
    return train_examples, eval_examples

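
# Worked example of the split above (illustration only): with 1000 shuffled
# examples and train_rate = 0.9, num_eval_examples == 100, so
#   fold == 0 -> eval = examples[900:1000], train = examples[:900]
#   fold == 2 -> eval = examples[100:200],  train = examples[:100] + examples[200:]
# i.e. fold 0 holds out the tail, and fold k >= 1 holds out the k-th leading
# block of num_eval_examples examples.
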
def load_spo_labeled_examples(data_generator,
                              predicate_labels,
                              seg_len=0,
                              seg_backoff=0,
                              num_augments=0,
                              allow_overlap=False):
    examples = []
    for guid, text, _, tags in data_generator(None):
        for (seg_text, ), seg_offset in seg_generator((text, ), seg_len,
                                                      seg_backoff):
            # Keep only tags whose subject and object both fall inside this
            # segment, re-based to segment-local offsets.
            seg_start = seg_offset
            seg_end = seg_offset + min(seg_len, len(seg_text))
            seg_tags = []
            for x in tags:
                s_start = x['sub']['start']
                s_mention = x['sub']['mention']
                predicate = x['predicate']
                o_start = x['obj']['start']
                o_mention = x['obj']['mention']

                # Trim surrounding blanks; skip tags that become unusable.
                s_start, s_mention = fix_mention_with_blank(s_start, s_mention)
                if s_start < 0:
                    continue
                o_start, o_mention = fix_mention_with_blank(o_start, o_mention)
                if o_start < 0:
                    continue

                s_end = s_start + len(s_mention)
                o_end = o_start + len(o_mention)
                if (s_start >= seg_start and s_end < seg_end
                        and o_start >= seg_start and o_end < seg_end):
                    s_start -= seg_start
                    o_start -= seg_start
                    s_end = s_start + len(s_mention)
                    o_end = o_start + len(o_mention)
                    assert seg_text[s_start:s_end] == s_mention, (
                        f"subject tag: |{seg_text[s_start:s_end]}| != mention: "
                        f"|{s_mention}|. seg: ({seg_start},{seg_end}), seg_tag: "
                        f"{((s_start, s_mention), predicate, (o_start, o_mention))}, "
                        f"seg_text: {seg_text}")
                    assert seg_text[o_start:o_end] == o_mention, (
                        f"object tag: |{seg_text[o_start:o_end]}| != mention: "
                        f"|{o_mention}|. seg: ({seg_start},{seg_end}), seg_tag: "
                        f"{((s_start, s_mention), predicate, (o_start, o_mention))}, "
                        f"seg_text: {seg_text}")
                    seg_tags.append(
                        ((s_start, s_mention), predicate, (o_start, o_mention)))
            if seg_tags:
                examples.append(
                    InputExample(guid=guid,
                                 text=seg_text,
                                 spo_list=seg_tags,
                                 text_offset=seg_start))
    logger.info(f"Loaded {len(examples)} examples.")
    return examples

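
# A minimal sketch of fix_mention_with_blank (hypothetical; not defined in
# this file). Inferred from the call sites above: trim surrounding whitespace
# from a mention and shift its start offset past any leading blanks; an
# all-blank mention is signalled with a negative start so the caller can skip
# the tag.
def fix_mention_with_blank(start, mention):
    stripped = mention.strip()
    if not stripped:
        return -1, mention
    leading = len(mention) - len(mention.lstrip())
    return start + leading, stripped
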
# A later variant of data_seg_generator; it shadows the definition above when
# both live in one module. Augmentation collection is disabled here.
def data_seg_generator(lines,
                       ner_labels,
                       seg_len=0,
                       seg_backoff=0,
                       num_augments=0,
                       allow_overlap=False):
    assert seg_len >= 0 and 0 <= seg_backoff <= seg_len
    all_text_entities = []
    labels_map = {}

    logger.warning(f"{len(lines)} lines")
    num_overlap = 0
    for i, line in enumerate(tqdm(lines)):
        guid = str(i)
        text = line['text']
        entities = line['entities']

        # Drop entities that *partially* overlap an already accepted span;
        # spans fully contained in (or containing) another span are kept.
        new_entities = []
        used_span = []
        entities = sorted(entities, key=lambda e: e.start)
        for entity in entities:
            if entity.category not in ner_labels:
                continue
            s = entity.start
            e = entity.end
            entity.mention = text[s:e + 1]
            overlap = False
            for us in used_span:
                # Starts inside us but ends beyond it.
                if us[0] < s <= us[1] < e:
                    overlap = True
                    break
                # Ends inside us but starts before it.
                if s < us[0] <= e < us[1]:
                    overlap = True
                    break
            if overlap:
                num_overlap += 1
                # NOTE: allow_overlap is ignored in this variant; partially
                # overlapping entities are always dropped.
                logger.warning(
                    f"Overlap! {i} mention: {entity.mention}({s}:{e}), "
                    f"used_span: {used_span}")
                continue
            used_span.append((s, e))
            new_entities.append(entity)
        entities = new_entities

        seg_offset = 0
        for (seg_text, ), text_offset in seg_generator((text, ), seg_len,
                                                       seg_backoff):
            text_a = seg_text
            assert text_offset == seg_offset
            seg_start = seg_offset
            seg_end = seg_offset + min(seg_len, len(seg_text))
            labels = [(x.category, x.start - seg_offset, x.end - seg_offset)
                      for x in entities
                      if x.start >= seg_offset and x.end < seg_end]

            # Segments without any annotation are not used for training.
            if labels:
                yield guid, text_a, None, labels, seg_offset

            seg_offset += seg_len - seg_backoff
    logger.warning(f"num_overlap: {num_overlap}")

    if num_augments > 0:
        # NOTE: all_text_entities and labels_map stay empty in this variant,
        # so this branch currently yields nothing.
        aug_tokens = augment_entities(all_text_entities,
                                      labels_map,
                                      num_augments=num_augments)
        for guid, text, entities in aug_tokens:
            text_a = text
            labels = [(entity['category'], entity['start'], entity['end'])
                      for entity in entities
                      if entity['end'] < (min(len(text_a), seg_len)
                                          if seg_len > 0 else len(text_a))]
            yield guid, text_a, None, labels, 0

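
# Usage sketch (hypothetical wiring, matching the signatures above): adapt
# data_seg_generator to the 4-tuple callable expected by load_ner_examples.
# The segment offset yielded by data_seg_generator is dropped here; pass
# seg_len=0 to load_ner_examples so already segmented text is not re-split.
def make_ner_data_generator(lines, ner_labels, seg_len=0, seg_backoff=0):
    def generator(_examples_file):
        for guid, text_a, text_b, labels, _seg_offset in data_seg_generator(
                lines, ner_labels, seg_len=seg_len, seg_backoff=seg_backoff):
            yield guid, text_a, text_b, labels

    return generator


# examples = load_ner_examples(
#     make_ner_data_generator(lines, ner_labels, seg_len=510, seg_backoff=128),
#     None, seg_len=0)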