Example #1
def load_test_examples(args, test_file, seg_len=0, seg_backoff=0):
    """Load unlabeled test examples, segmenting each text into overlapping windows."""
    test_examples = []
    for guid, text, _, _ in test_data_generator(args,
                                                test_file,
                                                seg_len=seg_len,
                                                seg_backoff=seg_backoff):
        #  text = clean_text(text)

        for seg_text, in seg_generator((text, ), seg_len, seg_backoff):
            test_examples.append(
                InputExample(guid=guid, text_a=seg_text, labels=None))

    logger.info(f"Loaded {len(test_examples)} test examples.")
    return test_examples
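
All of these examples call a seg_generator helper that is never shown, and the call sites disagree on its shape: Examples #1, #3, #5 and #7 unpack only the tuple of segment slices, while Examples #2, #4, #6, #8 and #9 also unpack a character offset. A minimal sketch of the offset-yielding variant, with the sliding-window behavior inferred from how seg_len and seg_backoff are used elsewhere (an assumption, not the project's actual helper):

def seg_generator(sequences, seg_len=0, seg_backoff=0):
    # Sketch only: yield aligned windows over parallel sequences plus the
    # window's start offset. With seg_len <= 0 the whole input is a single
    # segment; otherwise the window advances by (seg_len - seg_backoff).
    if seg_len <= 0:
        yield tuple(sequences), 0
        return
    stride = seg_len - seg_backoff
    assert stride > 0  # otherwise the window would never advance
    total = max(len(x) for x in sequences)
    offset = 0
    while offset < total:
        yield tuple(x[offset:offset + seg_len] for x in sequences), offset
        offset += stride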
Example #2
def load_ner_examples(data_generator, examples_file, seg_len=0, seg_backoff=0):
    """Load NER examples, splitting each text into segments that keep their character offset."""
    examples = []

    for guid, text_a, _, labels in data_generator(examples_file):
        assert text_a is not None
        for (seg_text_a, ), text_offset in seg_generator((text_a, ), seg_len,
                                                         seg_backoff):
            examples.append(
                InputExample(guid=guid,
                             text_a=seg_text_a,
                             labels=labels,
                             text_offset=text_offset))
    logger.info(f"Loaded {len(examples)} examples.")

    return examples
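
InputExample is likewise assumed rather than defined, and it is constructed with different keyword arguments across the examples (text_a vs. text, labels vs. spo_list, an optional text_offset). A permissive sketch with the field set inferred from those constructor calls (not the original class):

from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class InputExample:
    # Sketch only: fields are taken from the constructor calls in these examples.
    guid: str
    text_a: Optional[Any] = None   # segment text (or token list, as in Examples #3/#7)
    text: Optional[str] = None     # some variants use `text` instead of `text_a`
    labels: Any = None             # NER span labels or BIO tags; None for test data
    spo_list: Any = None           # (subject, predicate, object) triples
    text_offset: int = 0           # segment's character offset within the document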
Example #3
def load_test_examples(test_base_file):
    """Load unlabeled test examples from a JSON-lines file; seg_len and seg_backoff are module-level settings here."""
    with open(test_base_file, 'r') as fr:

        test_examples = []
        lines = fr.readlines()
        for line in tqdm(lines, desc="test"):
            d = json.loads(line)
            guid = d['doc_id']
            words = list(d['content'])

            for seg_words, in seg_generator((words, ), seg_len, seg_backoff):
                test_examples.append(
                    InputExample(guid=guid, text_a=seg_words, labels=None))

    logger.info(f"Loaded {len(test_examples)} test examples.")
    return test_examples
Example #4
def load_spo_examples(data_generator, examples_file, seg_len=0, seg_backoff=0):
    """Load SPO (subject-predicate-object) examples, segmenting each text."""
    examples = []

    for guid, text, text_b, spo_list in data_generator(examples_file):
        assert text is not None
        for (seg_text, ), text_offset in seg_generator((text, ), seg_len,
                                                       seg_backoff):

            examples.append(
                InputExample(guid=guid,
                             text=seg_text,
                             spo_list=spo_list,
                             text_offset=text_offset))
    logger.info(f"Loaded {len(examples)} examples.")

    return examples
Example #5
def train_data_generator(args, train_base_file, seg_len=0, seg_backoff=0):
    """Yield (guid, seg_text, None, seg_labels) with character-span event-argument labels; new labels are appended to the module-level ner_labels."""
    with open(train_base_file, 'r') as fr:
        lines = fr.readlines()
        for line in tqdm(lines, desc="train & eval"):
            d = json.loads(line)
            guid = d['doc_id']
            text = d['content']

            for (seg_text, ) in seg_generator((text, ), seg_len, seg_backoff):

                seg_labels = []
                for e in d['events']:
                    event_type = e['event_type']
                    #  if event_type not in ['破产清算']:  # ['股东减持', '股东增持']:
                    #      continue
                    for k, v in e.items():
                        if not v:
                            continue

                        if k not in ['event_id', 'event_type']:
                            label = '_'.join((event_type, k))

                            if label not in ner_labels:
                                ner_labels.append(label)

                            i0 = seg_text.find(v)
                            while i0 >= 0:
                                #  if i0 >= 0:
                                if len(v) == 1:
                                    #  if labels[i0] == 'O':
                                    #      labels[i0] = f"S-{label}"
                                    pass
                                else:
                                    assert i0 < seg_len and i0 + len(v) - 1 < seg_len
                                    seg_labels.append(
                                        (label, i0, i0 + len(v) - 1))
                                #  break
                                i0 = seg_text.find(v, i0 + len(v))

                yield guid, seg_text, None, seg_labels
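
The find loop above turns every occurrence of a non-empty event argument into a (label, start, end) span with an inclusive end index, skipping single-character mentions. The same pattern in isolation (a hypothetical helper and strings, for illustration only):

def find_spans(text, mention, label):
    # Collect (label, start, end) for every non-overlapping occurrence,
    # using the same inclusive-end convention as the generator above.
    spans = []
    i0 = text.find(mention)
    while i0 >= 0:
        spans.append((label, i0, i0 + len(mention) - 1))
        i0 = text.find(mention, i0 + len(mention))
    return spans

assert find_spans("abcabc", "abc", "X") == [("X", 0, 2), ("X", 3, 5)]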
Example #6
def data_seg_generator(lines,
                       ner_labels,
                       seg_len=0,
                       seg_backoff=0,
                       num_augements=0,
                       allow_overlap=False):
    """Yield segmented NER training items, dropping overlapping entities unless allow_overlap, and optionally collect data for augmentation."""
    assert seg_backoff >= 0 and seg_backoff <= int(seg_len * 3 / 4)
    all_text_entities = []
    labels_map = {}

    num_overlap = 0
    for i, s in enumerate(tqdm(lines)):
        #  logger.debug(f"s: {s}")
        guid = str(i)
        text = s['text'].strip()
        entities = s['entities']

        new_entities = []
        used_span = []
        #  logger.debug(f"entities: {entities}")
        entities = sorted(entities, key=lambda e: e.start)
        for entity in entities:
            if entity.category not in ner_labels:
                continue
            entity.mention = text[entity.start:entity.end + 1]
            s = entity.start
            e = entity.end

            overlap = False
            for us in used_span:
                if s >= us[0] and s <= us[1]:
                    overlap = True
                    break
                if e >= us[0] and e <= us[1]:
                    overlap = True
                    break
            if overlap:
                num_overlap += 1
                if not allow_overlap:
                    #  logger.warning(
                    #      f"Overlap! {i} mention: {entity.mention}({s}:{e}), used_span: {used_span}"
                    #  )
                    continue
            used_span.append((s, e))

            new_entities.append(entity)
        entities = new_entities

        seg_offset = 0
        #  if seg_len <= 0:
        #      seg_len = max_seq_length

        for (seg_text, ), text_offset in seg_generator((text, ), seg_len,
                                                       seg_backoff):
            text_a = seg_text

            assert text_offset == seg_offset

            seg_start = seg_offset
            seg_end = seg_offset + min(seg_len, len(seg_text))
            labels = [(x.category, x.start - seg_offset, x.end - seg_offset)
                      for x in entities
                      if x.start >= seg_offset and x.end < seg_end]

            # Segments that contain no annotations are not used for training
            if labels:
                yield guid, text_a, None, labels, seg_offset

                if num_augements > 0:
                    seg_entities = [
                        {
                            'start': x.start - seg_offset,
                            'end': x.end - seg_offset,
                            'category': x.category,
                            'mention': x.mention
                        } for x in entities
                        if x.start >= seg_offset and x.end < seg_end
                    ]
                    all_text_entities.append((guid, text_a, seg_entities))

                    for entity in seg_entities:
                        label_type = entity['category']
                        s = entity['start']  # - seg_offset
                        e = entity['end']  #- seg_offset
                        #  print(s, e)
                        assert e >= s
                        #  logger.debug(
                        #      f"seg_start: {seg_start}, seg_end: {seg_end}, seg_offset: {seg_offset}"
                        #  )
                        #  logger.debug(f"s: {s}, e: {e}")
                        assert s >= 0 and e < len(seg_text)
                        #  if s >= len(seg_text) or e >= len(seg_text):
                        #      continue

                        entity_text = seg_text[s:e + 1]
                        #  print(label_type, entity_text)

                        assert len(entity_text) > 0
                        if label_type not in labels_map:
                            labels_map[label_type] = []
                        labels_map[label_type].append(entity_text)

            seg_offset += seg_len - seg_backoff

    logger.warning(f"num_overlap: {num_overlap}")

    if num_augements > 0:
        aug_tokens = augement_entities(all_text_entities,
                                       labels_map,
                                       num_augements=num_augements)
        for guid, text, entities in aug_tokens:
            text_a = text
            for entity in entities:
                #  logger.debug(f"text_a: {text_a}")
                #  logger.debug(
                #      f"text_a[entity['start']:entity['end']]: {text_a[entity['start']:entity['end']]}"
                #  )
                #  logger.debug(
                #      f"mention {entity['mention']} in {text_a.find(entity['mention'])}"
                #  )
                #  logger.debug(f"entity: {entity}")
                assert text_a[entity['start']:entity['end'] + 1] == entity['mention']
            labels = [
                (entity['category'], entity['start'], entity['end'])
                for entity in entities if entity['end'] < (
                    min(len(text_a), seg_len) if seg_len > 0 else len(text_a))
            ]
            yield guid, text_a, None, labels, 0
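
Example #6's augmentation path depends on augement_entities (spelled augment_entities in Example #9), which is also not shown. A plausible sketch, assuming it builds each augmented copy by swapping every mention for a random same-category mention from labels_map and shifting the span offsets so the consuming assert (text_a[start:end + 1] == mention) still holds:

import random

def augement_entities(all_text_entities, labels_map, num_augements=1):
    # Sketch only: behavior inferred from the call sites, not the original.
    aug_tokens = []
    for guid, text, entities in all_text_entities:
        for k in range(num_augements):
            new_text = text
            new_entities = []
            delta = 0  # cumulative length change from earlier substitutions
            for entity in sorted(entities, key=lambda x: x['start']):
                candidates = labels_map.get(entity['category'], [])
                if not candidates:
                    continue
                replacement = random.choice(candidates)
                s = entity['start'] + delta
                e = entity['end'] + delta  # inclusive end
                new_text = new_text[:s] + replacement + new_text[e + 1:]
                new_entities.append({
                    'category': entity['category'],
                    'start': s,
                    'end': s + len(replacement) - 1,
                    'mention': replacement,
                })
                delta += len(replacement) - (e - s + 1)
            aug_tokens.append((f"{guid}-aug{k}", new_text, new_entities))
    return aug_tokens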
Example #7
def load_train_eval_examples(train_base_file):
    """Build BIO-tagged examples and split them into train/eval sets; seg_len, seg_backoff, fold and ner_labels are module-level."""
    label2id = {}
    id2label = {}
    train_base_examples = []
    with open(train_base_file, 'r') as fr:
        lines = fr.readlines()
        for line in tqdm(lines, desc="train & eval"):
            d = json.loads(line)
            guid = d['doc_id']
            text = d['content']
            words = list(text)
            labels = ['O'] * len(words)
            for e in d['events']:
                event_type = e['event_type']
                #  if event_type not in ['破产清算']:  # ['股东减持', '股东增持']:
                #      continue
                for k, v in e.items():
                    if not v:
                        continue

                    if k not in ['event_id', 'event_type']:
                        label = '_'.join((event_type, k))
                        if label not in label2id:
                            n = len(label2id) + 1
                            label2id[label] = n
                            id2label[n] = label
                            ner_labels.append(label)
                        n = label2id[label]
                        i0 = text.find(v)
                        while i0 >= 0:
                            #  if i0 >= 0:
                            if len(v) == 1:
                                #  if labels[i0] == 'O':
                                #      labels[i0] = f"S-{label}"
                                pass
                            else:
                                labels[i0] = f"B-{label}"
                                for j0 in range(1, len(v)):
                                    labels[i0 + j0] = f"I-{label}"
                            i0 = text.find(v, i0 + 1)

            for seg_words, seg_labels in seg_generator((words, labels),
                                                       seg_len, seg_backoff):
                train_base_examples.append(
                    InputExample(guid=guid,
                                 text_a=seg_words,
                                 labels=seg_labels))

            #  if seg_len > 0:
            #      n_segs = len(text) // seg_len
            #      for i in range(n_segs + 1):
            #          s0 = seg_len * i
            #          s1 = seg_len * (i + 1) if i < n_segs - 1 else len(text)
            #          if s0 < s1:
            #              seg_text = text[s0:s1]
            #              seg_words = words[s0:s1]
            #              seg_labels = labels[s0:s1]
            #
            #              train_base_examples.append(
            #                  InputExample(guid=guid,
            #                               text_a=seg_words,
            #                               labels=seg_labels))
            #  else:
            #      train_base_examples.append(
            #          InputExample(guid=guid, text_a=words, labels=labels))
    #  train_base_examples = train_base_examples[:100]

    random.shuffle(train_base_examples)
    train_rate = 0.9

    num_eval_examples = int(len(train_base_examples) * (1 - train_rate))
    num_train_samples = len(train_base_examples) - num_eval_examples

    if fold == 0:
        eval_examples = train_base_examples[num_train_samples:]
        train_examples = train_base_examples[:num_train_samples]
    else:
        s = num_eval_examples * (fold - 1)
        e = num_eval_examples * fold
        eval_examples = train_base_examples[s:e]
        train_examples = train_base_examples[:s] + train_base_examples[e:]
    logger.info(
        f"Loaded {len(train_examples)} train examples, {len(eval_examples)} eval examples."
    )
    return train_examples, eval_examples
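
The tail of Example #7 is a simple k-fold split driven by the module-level fold: with train_rate = 0.9 the shuffled examples fall into ten roughly equal slices, fold == 0 evaluates on the last slice, and fold >= 1 evaluates on the fold-th slice from the front. For instance, with 1000 examples, fold == 3 gives s = 200 and e = 300, so examples 200..299 form the eval set and the remaining 900 are used for training.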
Example #8
def load_spo_labeled_examples(data_generator,
                              predicate_labels,
                              seg_len=0,
                              seg_backoff=0,
                              num_augments=0,
                              allow_overlap=False):
    """Load SPO examples whose subject and object mentions both fall inside the current segment."""

    examples = []
    for guid, text, _, tags in data_generator(None):

        for (seg_text, ), seg_offset in seg_generator((text, ), seg_len,
                                                      seg_backoff):
            # Filter tags by text_offset

            seg_start = seg_offset
            seg_end = seg_offset + min(seg_len, len(seg_text))
            #  logger.info(f"tags: {tags}")
            #  seg_tags = [
            #      ((x[0][0] - seg_start, x[0][1]), x[1], (x[2][0] - seg_start,
            #                                              x[2][1])) for x in tags
            #      if x[0][0] >= seg_start and x[0][0] + len(x[0][1]) < seg_end
            #      and x[2][0] >= seg_start and x[2][0] + len(x[2][1]) < seg_end
            #  ]
            seg_tags = []
            for x in tags:
                #  s_start, s_mention = x[0]
                #  predicate = x[1]
                #  o_start, o_mention = x[2]

                s_category = x['sub']['category']
                s_start = x['sub']['start']
                s_mention = x['sub']['mention']
                predicate = x['predicate']
                o_category = x['obj']['category']
                o_start = x['obj']['start']
                o_mention = x['obj']['mention']

                #  logger.warning(
                #      f"({s_category}, {s_start}, {s_mention}), {predicate}, ({o_category}, {o_start}, {o_mention})"
                #  )
                #  for seg_tag in seg_tags:
                #  (s_start, s_mention), predicate, (o_start, o_mention) = seg_tag

                s_start, s_mention = fix_mention_with_blank(s_start, s_mention)
                if s_start < 0:
                    continue
                o_start, o_mention = fix_mention_with_blank(o_start, o_mention)
                if o_start < 0:
                    continue
                s_end = s_start + len(s_mention)
                o_end = o_start + len(o_mention)

                if s_start >= seg_start and s_end < seg_end and o_start >= seg_start and o_end < seg_end:
                    s_start -= seg_start
                    o_start -= seg_start
                    s_end = s_start + len(s_mention)
                    o_end = o_start + len(o_mention)

                    assert seg_text[
                        s_start:
                        s_end] == s_mention, f"subject tag: |{seg_text[s_start:s_end]}| != mention: |{s_mention}|. seg: ({seg_start},{seg_end}), seg_tag: {((s_start, s_mention), predicate, (o_start, o_mention))}, seg_text: {seg_text}"
                    assert seg_text[
                        o_start:
                        o_end] == o_mention, f"object tag: |{seg_text[o_start:o_end]}| != mention: |{o_mention}|. seg: ({seg_start},{seg_end}), seg_tag: {((s_start, s_mention), predicate, (o_start, o_mention))}, seg_text: {seg_text}"
                    seg_tags.append(((s_start, s_mention), predicate,
                                     (o_start, o_mention)))
            if seg_tags:
                examples.append(
                    InputExample(guid=guid,
                                 text=seg_text,
                                 spo_list=seg_tags,
                                 text_offset=seg_start))

    logger.info(f"Loaded {len(examples)} examples.")
    return examples
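
fix_mention_with_blank is undefined here; from the call site, it returns an adjusted (start, mention) pair and a negative start means "drop this tag". A plausible sketch, assuming its job is to trim surrounding whitespace from a mention and shift the start offset to match (an inference, not the original implementation):

def fix_mention_with_blank(start, mention):
    # Sketch only: strip surrounding whitespace, shift start accordingly,
    # and signal an unusable (all-blank) mention with start = -1.
    stripped = mention.strip()
    if not stripped:
        return -1, mention
    leading = len(mention) - len(mention.lstrip())
    return start + leading, stripped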
Example #9
def data_seg_generator(lines,
                       ner_labels,
                       seg_len=0,
                       seg_backoff=0,
                       num_augments=0,
                       allow_overlap=False):
    """Yield segmented NER training items; only crossing (partially overlapping) entity spans are dropped."""
    #  assert seg_backoff >= 0 and seg_backoff <= int(seg_len * 3 / 4)
    assert seg_len >= 0 and seg_backoff >= 0 and seg_backoff <= seg_len
    all_text_entities = []
    labels_map = {}

    logger.warning(f"{len(lines) in lines}")
    num_overlap = 0
    for i, s in enumerate(tqdm(lines)):
        guid = str(i)
        text = s['text']
        entities = s['entities']

        new_entities = []
        used_span = []
        #  logger.debug(f"entities: {entities}")
        entities = sorted(entities, key=lambda e: e.start)
        for entity in entities:
            if entity.category not in ner_labels:
                continue
            s = entity.start
            e = entity.end
            entity.mention = text[s:e + 1]

            overlap = False
            for us in used_span:
                if s > us[0] and s <= us[1] and e > us[1]:
                    overlap = True
                    logger.warning(f"Overlap: ({s},{e}) vs ({us[0],us[1]})")
                    break
                if e >= us[0] and e < us[1] and s < us[0]:
                    overlap = True
                    logger.warning(f"Overlap: ({s},{e}) vs ({us[0],us[1]})")
                    break

                #  if s >= us[0] and s <= us[1]:  # and e > us[1]:
                #      overlap = True
                #      logger.warning(f"Overlap: ({s},{e}) vs ({us[0],us[1]})")
                #      break
                #  if e >= us[0] and e <= us[1]:  # and s < us[0]:
                #      overlap = True
                #      logger.warning(f"Overlap: ({s},{e}) vs ({us[0],us[1]})")
                #      break
                #  if us[0] >= s and us[0] <= e:  # and e > us[1]:
                #      overlap = True
                #      logger.warning(f"Overlap: ({s},{e}) vs ({us[0],us[1]})")
                #      break
                #  if us[1] >= s and us[1] <= e:  # and s < us[0]:
                #      overlap = True
                #      logger.warning(f"Overlap: ({s},{e}) vs ({us[0],us[1]})")
                #      break

                #  if s >= us[0] and s <= us[1]:
                #      overlap = True
                #      logger.warning(f"Overlap: ({s},{e}) vs ({us[0],us[1]})")
                #      break
                #  if e >= us[0] and e <= us[1]:
                #      overlap = True
                #      logger.warning(f"Overlap: ({s},{e}) vs ({us[0],us[1]})")
                #      break
            if overlap:
                num_overlap += 1
                #  if not allow_overlap:
                logger.warning(
                    f"Overlap! {i} mention: {entity.mention}({s}:{e}), used_span: {used_span}"
                )
                continue
            used_span.append((s, e))

            new_entities.append(entity)
        #  logger.warning(f"{len(new_entities)} new_entities")
        entities = new_entities

        seg_offset = 0
        #  if seg_len <= 0:
        #      seg_len = max_seq_length

        for (seg_text, ), text_offset in seg_generator((text, ), seg_len,
                                                       seg_backoff):
            text_a = seg_text

            assert text_offset == seg_offset

            seg_start = seg_offset
            seg_end = seg_offset + min(seg_len, len(seg_text))
            labels = [(x.category, x.start - seg_offset, x.end - seg_offset)
                      for x in entities
                      if x.start >= seg_offset and x.end < seg_end]

            # Segments that contain no annotations are not used for training
            if labels:
                yield guid, text_a, None, labels, seg_offset

            #      if num_augments > 0:
            #          seg_entities = [
            #              {
            #                  'start': x.start - seg_offset,
            #                  'end': x.end - seg_offset,
            #                  'category': x.category,
            #                  'mention': x.mention
            #              } for x in entities
            #              if x.start >= seg_offset and x.end < seg_end
            #          ]
            #          all_text_entities.append((guid, text_a, seg_entities))
            #
            #          for entity in seg_entities:
            #              label_type = entity['category']
            #              s = entity['start']  # - seg_offset
            #              e = entity['end']  #- seg_offset
            #              #  print(s, e)
            #              assert e >= s
            #              #  logger.debug(
            #              #      f"seg_start: {seg_start}, seg_end: {seg_end}, seg_offset: {seg_offset}"
            #              #  )
            #              #  logger.debug(f"s: {s}, e: {e}")
            #              assert s >= 0 and e < len(seg_text)
            #              #  if s >= len(seg_text) or e >= len(seg_text):
            #              #      continue
            #
            #              entity_text = seg_text[s:e + 1]
            #              #  print(label_type, entity_text)
            #
            #              assert len(entity_text) > 0
            #              if label_type not in labels_map:
            #                  labels_map[label_type] = []
            #              labels_map[label_type].append(entity_text)

            seg_offset += seg_len - seg_backoff

    logger.warning(f"num_overlap: {num_overlap}")

    if num_augments > 0:
        aug_tokens = augment_entities(all_text_entities,
                                      labels_map,
                                      num_augments=num_augments)
        for guid, text, entities in aug_tokens:
            text_a = text
            for entity in entities:
                #  logger.debug(f"text_a: {text_a}")
                #  logger.debug(
                #      f"text_a[entity['start']:entity['end']]: {text_a[entity['start']:entity['end']]}"
                #  )
                #  logger.debug(
                #      f"mention {entity['mention']} in {text_a.find(entity['mention'])}"
                #  )
                #  logger.debug(f"entity: {entity}")

                #  assert text_a[entity['start']:entity['end'] + 1] == entity['mention']
                pass

            labels = [
                (entity['category'], entity['start'], entity['end'])
                for entity in entities if entity['end'] < (
                    min(len(text_a), seg_len) if seg_len > 0 else len(text_a))
            ]
            yield guid, text_a, None, labels, 0
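
Note that, unlike Example #6, the two live overlap conditions in Example #9 only reject spans that cross an already-used span; a span fully nested inside, fully containing, or identical to a used span passes through. A small standalone check makes the difference concrete:

def crosses(span, used_span):
    # Reproduces Example #9's live overlap test: True only for partial
    # (crossing) overlaps, never for nesting, containment, or equality.
    s, e = span
    for us in used_span:
        if s > us[0] and s <= us[1] and e > us[1]:
            return True
        if e >= us[0] and e < us[1] and s < us[0]:
            return True
    return False

assert crosses((3, 8), [(0, 5)])        # extends past the right edge
assert crosses((0, 3), [(2, 5)])        # extends past the left edge
assert not crosses((1, 3), [(0, 5)])    # nested: allowed
assert not crosses((0, 9), [(2, 5)])    # containing: allowed
assert not crosses((0, 5), [(0, 5)])    # identical: allowed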