Example #1
def evaluate(verbose=False):
    entries = Entry.objects(set='train')

    # evaluation
    num, dem = 0, 0

    for entry in entries:
        entity_map, predicates = utils.map_entities(entry.triples)

        for lexEntry in entry.texts:
            local_num = 0
            dem += len(entity_map.keys())

            template = lexEntry.template
            for tag in entity_map:
                if tag in template:
                    local_num += 1
                    num += 1

            if local_num != len(entity_map.keys()) and verbose:
                print entity_map
                print lexEntry.template
                print 20 * '-'

    print 'Evaluation: ', (float(num) / dem)
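From the way these snippets consume it, utils.map_entities(triples) appears to return a tag -> entity mapping (keys such as AGENT-1, values exposing a .name) together with the list of predicates. Example #1 then measures how many of those tags actually show up in each delexicalized template. A minimal sketch of that coverage ratio, using hypothetical tags and plain names in place of entity objects:

# Hypothetical entity map and template; Example #1 gets these from
# utils.map_entities and lexEntry.template respectively.
entity_map = {'AGENT-1': 'Alan_Bean', 'PATIENT-1': 'Apollo_12'}
template = 'AGENT-1 was a crew member of PATIENT-1 .'

num = sum(1 for tag in entity_map if tag in template)  # tags found in the template
dem = len(entity_map)                                  # tags that should be there
print('Evaluation:', float(num) / dem)                 # 1.0 -> every tag is realized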
Example #2
    def extract_template(self, triples):
        # entity and predicate mapping
        entitymap, predicates = utils.map_entities(triples=triples)

        trainentries = Entry.objects(size=len(triples), set='train')
        for i, triple in enumerate(triples):
            trainentries = filter(lambda entry: entry.triples[i].predicate.name == triple.predicate.name, trainentries)

        # extract templates
        templates = []
        for entry in trainentries:
            for lexEntry in entry.texts:
                template = lexEntry.template

                entitiesPresence = True
                for tag in entitymap:
                    if tag not in template:
                        entitiesPresence = False
                        break
                if entitiesPresence:
                    templates.append(template)

        templates = nltk.FreqDist(templates)
        item = sorted(templates.items(), key=operator.itemgetter(1), reverse=True)
        if len(item) == 0:
            template, freq = '', 0
        else:
            template, freq = item[0]
            # Replace entity tags with Wikipedia IDs
            for tag, entity in sorted(entitymap.items(), key=lambda x: len(x[1].name), reverse=True):
                template = template.replace(tag, '_'.join(entity.name.replace('\'', '').replace('\"', '').split()))
        return template, entitymap, predicates
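The last step of Example #2 relexicalizes the winning template: every tag is replaced by the underscore-joined entity name, longer names first, with quotes stripped. A toy sketch of that substitution with hypothetical tag/name pairs (the real entitymap values are entity objects, hence x[1].name in the original):

# Hypothetical tag -> name pairs; in Example #2 these come from entitymap.items().
entities = {'AGENT-1': 'Alan Bean', 'PATIENT-1': 'Apollo 12'}
template = 'AGENT-1 was a crew member of PATIENT-1 .'
for tag, name in sorted(entities.items(), key=lambda x: len(x[1]), reverse=True):
    wiki_id = '_'.join(name.replace("'", '').replace('"', '').split())
    template = template.replace(tag, wiki_id)
print(template)  # Alan_Bean was a crew member of Apollo_12 .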
Example #3
    def run(self, entry, _set):
        print 10 * '-'
        print 'ID:', str(entry.docid), str(entry.size), str(entry.category)

        if _set == 'dev':
            # extract references (gold standard)
            self.extract_gold_standard(entry.texts)

        # ordering triples
        triples = entry.triples
        semcategory = entry.category
        striples = self.order_process(triples, semcategory)

        # Templates selection
        templates = []
        for triples in striples:
            templates.extend(self.template_process(triples[0], semcategory))

        templates = sorted(templates,
                           key=lambda template: template[1],
                           reverse=True)[:self.beam]

        # Referring expression generation
        entitymap, predicates = utils.map_entities(entry.triples)
        templates = map(lambda template: ' '.join(template[0]), templates)
        templates = self.reg_process(templates, triples[0], entitymap)

        # Ranking with KenLM
        templates = sorted(templates,
                           key=lambda x: self.model.score(x),
                           reverse=True)
        if len(templates) > 0:
            template = templates[0]
        else:
            template = ''

        if _set == 'dev':
            self.hyps.append(template.strip())
        else:
            self.hyps_test.append(template.strip())

        print 'Entities: ', str(
            map(lambda x: (x[0], x[1].name), entitymap.items()))
        print 'Predicate: ', str(
            map(lambda predicate: predicate.name, predicates))
        print template.encode('utf-8')
        print 10 * '-'
        return template
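Example #3 picks its final realization by ranking the candidate strings with a KenLM language model (self.model.score). A standalone sketch of that ranking step, assuming the kenlm Python bindings and a hypothetical pre-trained model file:

import kenlm

# Hypothetical model path; Example #3 already holds the model in self.model.
model = kenlm.Model('lm.binary')
candidates = ['Alan Bean was born in Wheeler .', 'Wheeler was born in Alan Bean .']
# score returns the total log10 probability; higher (less negative) ranks first.
best = sorted(candidates, key=lambda x: model.score(x), reverse=True)[0]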
Example #4
    def process(self, entry):
        '''
        :param entry: Entry with the original triples and their lexicalizations
        :return: list of training instances (sorted triples, original triples,
            extracted template, lexicalization entry, category and delex type)
        '''
        self.entry = entry
        entitymap, predicates = utils.map_entities(self.entry.triples)

        training_set = []
        for lex in self.entry.texts:
            template = lex.template
            delex_type = lex.delex_type

            if self.check_tagfrequency(entitymap, template):
                sort_triples, triples = [], copy.deepcopy(entry.triples)
                out = self.proc.parse_doc(template)

                prev_tags = []
                for i, snt in enumerate(out['sentences']):
                    tags = []

                    # get tags in order
                    for token in snt['tokens']:
                        if token in entitymap:
                            tags.append(token)

                    # Ordering the triples in the sentence i
                    sort_snt_triples, triples = self.order(
                        triples, entitymap, prev_tags, tags)
                    sort_triples.extend(sort_snt_triples)

                # Extract template for the sentence
                if len(triples) == 0:
                    template = []
                    for snt in out['sentences']:
                        template.extend(snt['tokens'])
                    template = self.generate_template(sort_triples, template,
                                                      entitymap)
                    training_set.append({
                        'sorted_triples': sort_triples,
                        'triples': entry.triples,
                        'template': template,
                        'lexEntry': lex,
                        'semcategory': entry.category,
                        'delex_type': delex_type
                    })
        return training_set
Example #5
    def generate_template(self, triples, template, entitymap):
        '''
        :param triples: triples sorted in the order they are realized
        :param template: tokenized sentence carrying the original entity tags
        :param entitymap: original tag -> entity map
        :return: template string retagged according to the sorted triples
        '''
        new_entitymap, predicates = utils.map_entities(triples)
        new_entitymap = dict(
            map(lambda x: (x[1].name, x[0]), new_entitymap.items()))
        new_template = []
        for token in template:
            if token in entitymap:
                new_template.append(new_entitymap[entitymap[token].name])
            else:
                new_template.append(token)
        return ' '.join(new_template).replace('-LRB-',
                                              '(').replace('-RRB-',
                                                           ')').strip()
Example #6
def get_parallel(set, delex=True, size=10, evaluation=False):
    entries = Entry.objects(size__lte=size, set=set)
    proc = CoreNLP('ssplit')

    de, en, entity_maps = [], [], []
    for entry in entries:
        entity_map, predicates = utils.map_entities(entry.triples)
        entity2tag = utils.entity2tag(entity_map)

        source = ''
        for triple in entry.triples:
            agent = triple.agent.name
            tag_agent = entity2tag[agent]

            predicate = triple.predicate.name

            patient = triple.patient.name
            tag_patient = entity2tag[patient]

            if delex:
                source += tag_agent
            else:
                source += agent
            source += ' '
            source += predicate
            source += ' '
            if delex:
                source += tag_patient
            else:
                source += patient
            source += ' '

            if not DELEX and set in ['train', 'dev'] and not evaluation:
                de.append(agent)
                name = ' '.join(
                    agent.replace('\'', '').replace('\"', '').split('_'))
                out = proc.parse_doc(name)
                text = ''
                for snt in out['sentences']:
                    text += ' '.join(snt['tokens']).replace('-LRB-',
                                                            '(').replace(
                                                                '-RRB-', ')')
                    text += ' '
                en.append(text.strip())

                de.append(patient)
                name = ' '.join(
                    patient.replace('\'', '').replace('\"', '').split('_'))
                out = proc.parse_doc(name)
                text = ''
                for snt in out['sentences']:
                    text += ' '.join(snt['tokens']).replace('-LRB-',
                                                            '(').replace(
                                                                '-RRB-', ')')
                    text += ' '
                en.append(text.strip())

        target_list = []
        for lexEntry in entry.texts:
            if delex and not evaluation:
                target = lexEntry.template
            else:
                target = lexEntry.text
            out = proc.parse_doc(target)

            text = ''
            for snt in out['sentences']:
                text += ' '.join(snt['tokens']).replace('-LRB-', '(').replace(
                    '-RRB-', ')')
                text += ' '
            target = text.strip()
            target_list.append(target)

            print source
            print target
            print 10 * '-'
            if not evaluation:
                entity_maps.append(entity_map)
                de.append(source.strip())
                en.append(target)
        if evaluation:
            entity_maps.append(entity_map)
            de.append(source.strip())
            en.append(target_list)
        elif set == 'test':
            entity_maps.append(entity_map)
            de.append(source.strip())
    return de, en, entity_maps
Example #7
    def _extract_template(self, triples, semcategory=''):
        '''
        Extract templates based on the triple set
        :param triples: triple set
        :param semcategory: semantic category used to restrict the search (optional)
        :return: templates, tag->entity map and predicates
        '''
        # entity and predicate mapping
        entitymap, predicates = utils.map_entities(triples=triples)

        # Select templates whose predicates match the input triples' predicates (in the same order)
        if self.delex_type in ['automatic', 'manual']:
            if semcategory == '':
                train_templates = Template.objects(
                    Q(triples__size=len(triples))
                    & Q(delex_type=self.delex_type))
            else:
                train_templates = Template.objects(
                    Q(triples__size=len(triples))
                    & Q(delex_type=self.delex_type) & Q(category=semcategory))
        else:
            if semcategory == '':
                train_templates = Template.objects(triples__size=len(triples))
            else:
                train_templates = Template.objects(
                    Q(triples__size=len(triples)) & Q(category=semcategory))

        for i, triple in enumerate(triples):
            train_templates = filter(
                lambda train_template: train_template.triples[i].predicate.name
                == triple.predicate.name, train_templates)

        # extract templates
        templates = []
        for entry in train_templates:
            template = entry.template

            entitiesPresence = True
            for tag in entitymap:
                if tag not in template:
                    entitiesPresence = False
                    break
            if entitiesPresence:
                templates.append(template)

        templates = nltk.FreqDist(templates)
        templates = sorted(templates.items(),
                           key=operator.itemgetter(1),
                           reverse=True)

        new_templates = []
        dem = sum(map(lambda item: item[1], templates))
        for item in templates:
            template, freq = item
            # Replace entity tags with Wikipedia IDs
            for tag, entity in sorted(entitymap.items(),
                                      key=lambda x: len(x[1].name),
                                      reverse=True):
                template = template.replace(
                    tag, '_'.join(
                        entity.name.replace('\'', '').replace('\"',
                                                              '').split()))
            new_templates.append((template, float(freq) / dem))

        if len(new_templates) == 0 and semcategory != '':
            new_templates, entitymap, predicates = self._extract_template(
                triples, '')

        return new_templates, entitymap, predicates
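Examples #2 and #7 both rank candidate templates by corpus frequency with nltk.FreqDist; #7 additionally divides each count by the total, so the second element of every returned pair is a relative frequency, and it falls back to an unconstrained search (empty category) when nothing survives the filters. A small sketch of the frequency ranking with made-up templates:

import nltk

templates = ['AGENT-1 is located in PATIENT-1 .',
             'AGENT-1 is located in PATIENT-1 .',
             'AGENT-1 lies in PATIENT-1 .']
dist = nltk.FreqDist(templates)
dem = sum(dist.values())
ranked = [(template, float(freq) / dem)
          for template, freq in sorted(dist.items(), key=lambda x: x[1], reverse=True)]
# [('AGENT-1 is located in PATIENT-1 .', 0.66...), ('AGENT-1 lies in PATIENT-1 .', 0.33...)]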
Example #8
def generate_re_test_file(
        ehr_record: HealthRecord,
        max_len: int = 128) -> Tuple[List[str], List[Relation]]:
    """
    Generates test file for Relation Extraction.

    Parameters
    -----------
    ehr_record : HealthRecord
        The EHR record with entities set.

    max_len : int
        The maximum length of sequence.

    Returns
    --------
    Tuple[List[str], List[Relation]]
        A list of sequences with each entity replaced by its tag,
        and a list of Relation objects representing the relations in those sequences.
    """
    random.seed(0)

    re_text_list = []
    relation_list = []

    text = ehr_record.text
    entities = ehr_record.get_entities()
    if isinstance(entities, dict):
        entities = list(entities.values())

    # get character split points
    char_split_points = get_char_split_points(ehr_record, max_len)

    start = 0
    end = char_split_points[0]

    for i in range(len(char_split_points)):
        # Obtain only entities within the split text
        range_entities = [
            ent for ent in entities
            if int(ent[0]) >= start and int(ent[1]) <= end
        ]

        # Get all possible relations within the split text
        possible_relations = utils.map_entities(range_entities)

        for rel, label in possible_relations:
            split_text = text[start:end]
            split_offset = start

            ent1 = rel.get_entities()[0]
            ent2 = rel.get_entities()[1]

            # Check if both entities are within split text
            if ent1[0] >= start and ent1[1] < end and \
                    ent2[0] >= start and ent2[1] < end:

                modified_text = replace_entity_text(split_text, ent1, ent2,
                                                    split_offset)

                # Replace un-required characters with space
                final_text = modified_text.replace('\n',
                                                   ' ').replace('\t', ' ')

                re_text_list.append(final_text)
                relation_list.append(rel)

        start = end
        if i != len(char_split_points) - 1:
            end = char_split_points[i + 1]
        else:
            end = len(text) + 1

    assert len(re_text_list) == len(relation_list)

    return re_text_list, relation_list
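Example #8 processes the record window by window: start/end walk over the precomputed character split points, and only entities (and hence candidate relations) that fall entirely inside the current window are kept. A stripped-down sketch of that windowing loop, assuming get_char_split_points returns a sorted list of character offsets whose last value marks the end of the text:

# Hypothetical split points for a 350-character document.
char_split_points = [120, 250, 350]
text_length = 350

start, end = 0, char_split_points[0]
windows = []
for i in range(len(char_split_points)):
    windows.append((start, end))  # relations are generated from entities inside [start, end]
    start = end
    if i != len(char_split_points) - 1:
        end = char_split_points[i + 1]
    else:
        end = text_length + 1
# windows == [(0, 120), (120, 250), (250, 350)]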
Example #9
def generate_re_input_files(ehr_records: List[HealthRecord],
                            filename: str,
                            ade_records: List[Dict] = None,
                            max_len: int = 128,
                            is_test=False,
                            is_label=True,
                            is_predict=False,
                            sep: str = '\t'):

    random.seed(0)

    index = 0
    index_rel_label_map = []

    with open(filename, 'w') as file:
        # Write headers
        write_file(file, 'index', 'sentence', 'label', sep, is_test, is_label)

        # Preprocess EHR records
        for record in ehr_records:
            text = record.text
            entities = record.get_entities()

            if is_predict:
                true_relations = None
            else:
                true_relations = record.get_relations()

            # get character split points
            char_split_points = get_char_split_points(record, max_len)

            start = 0
            end = char_split_points[0]

            for i in range(len(char_split_points)):
                # Obtain only entities within the split text
                range_entities = {
                    ent_id: ent
                    for ent_id, ent in entities.items()
                    if int(ent[0]) >= start and int(ent[1]) <= end
                }

                # Get all possible relations within the split text
                possible_relations = utils.map_entities(
                    range_entities, true_relations)

                for rel, label in possible_relations:
                    if label == 0 and rel.name != "ADE-Drug":
                        if random.random() > 0.25:
                            continue

                    split_text = text[start:end]
                    split_offset = start

                    ent1 = rel.get_entities()[0]
                    ent2 = rel.get_entities()[1]

                    # Check if both entities are within split text
                    if ent1.range[0] >= start and ent1.range[1] < end and \
                            ent2.range[0] >= start and ent2.range[1] < end:

                        modified_text = replace_entity_text(
                            split_text, ent1, ent2, split_offset)

                        # Replace un-required characters with space
                        final_text = modified_text.replace('\n', ' ').replace(
                            '\t', ' ')
                        write_file(file, index, final_text, label, sep,
                                   is_test, is_label)

                        if is_predict:
                            index_rel_label_map.append({'relation': rel})
                        else:
                            index_rel_label_map.append({
                                'label': label,
                                'relation': rel
                            })

                        index += 1

                start = end
                if i != len(char_split_points) - 1:
                    end = char_split_points[i + 1]
                else:
                    end = len(text) + 1

        # Preprocess ADE records
        if ade_records is not None:
            for record in ade_records:
                entities = record['entities']
                true_relations = record['relations']
                possible_relations = utils.map_entities(
                    entities, true_relations)

                for rel, label in possible_relations:

                    if label == 1 and random.random() > 0.5:
                        continue

                    new_tokens = record['tokens'].copy()

                    for ent in rel.get_entities():
                        ent_type = ent.name

                        start_tok = ent.range[0]
                        end_tok = ent.range[1] + 1

                        for i in range(start_tok, end_tok):
                            new_tokens[i] = '@' + ent_type + '$'
                    """Remove consecutive repeating entities.
                    Eg. this is @ADE$ @ADE$ @ADE$ for @Drug$ @Drug$ -> this is @ADE$ for @Drug$"""
                    final_tokens = [new_tokens[i] for i in range(len(new_tokens))\
                                    if (i == 0) or new_tokens[i] != new_tokens[i-1]]

                    final_text = " ".join(final_tokens)

                    write_file(file, index, final_text, label, sep, is_test,
                               is_label)
                    index_rel_label_map.append({
                        'label': label,
                        'relation': rel
                    })
                    index += 1

    # Strip only the extension; rsplit keeps any earlier dots in the path intact
    filename, ext = filename.rsplit('.', 1)
    utils.save_pickle(filename + '_rel.pkl', index_rel_label_map)
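In Example #9 the candidate pairs produced by utils.map_entities are heavily skewed toward the negative class, so negatives outside ADE-Drug are randomly downsampled (roughly one in four kept) and, for the ADE corpus, positives are halved. A minimal sketch of that probabilistic downsampling with hypothetical labelled candidates:

import random

random.seed(0)
# Hypothetical (candidate, label) pairs: 1 = related, 0 = not related.
candidates = [('pair-%d' % i, 1 if i % 5 == 0 else 0) for i in range(100)]

kept = []
for cand, label in candidates:
    # Keep every positive, but only ~25% of the negatives, as in Example #9.
    if label == 0 and random.random() > 0.25:
        continue
    kept.append((cand, label))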