Ejemplo n.º 1
0
def find_label_pos(label, sys_utt_tokens, usr_utt_tokens):
    """Locate a label, or one of its semantic variants, in both utterances.

    The label itself is tried first, then each alternative listed for it in
    ``SEMANTIC_DICT``. The first variant whose tokens occur in either the
    system or the user token sequence wins.

    Returns:
        A ``(matched_label, sys_positions, usr_positions)`` triple. The
        position lists are the materialized results of ``utils.find_sublist``
        for the matched variant. If no variant matches, the original ``label``
        is returned together with the (empty) results of the last attempt.
    """
    variants = [label] + SEMANTIC_DICT.get(label, [])
    for variant in variants:
        variant_tokens = utils.tokenize(variant)
        sys_hits = list(utils.find_sublist(sys_utt_tokens, variant_tokens))
        usr_hits = list(utils.find_sublist(usr_utt_tokens, variant_tokens))
        if not (sys_hits or usr_hits):
            continue
        return variant, sys_hits, usr_hits
    # ``variants`` always contains ``label``, so both lists are bound here.
    return label, sys_hits, usr_hits
Ejemplo n.º 2
0
def advanced_tagger(sentence):
    """POS-tag ``sentence`` with each extracted noun phrase collapsed into one token.

    Noun phrases come from ``X_extractor(sentence, 'NP')`` (presumably a list
    of phrase strings -- TODO confirm). For each phrase, the token span
    reported by ``utils.find_sublist`` is replaced by the phrase string in the
    token list, and by ``(phrase, 'NP')`` in the tag list.

    Returns:
        ``(tokens, tagged)``: the merged token list and the parallel list of
        ``(token, tag)`` pairs.
    """
    noun_phrases = X_extractor(sentence, 'NP')
    tokens = nltk.word_tokenize(sentence)
    tagged = nltk.pos_tag(tokens)
    for phrase in noun_phrases:
        phrase_tokens = nltk.word_tokenize(phrase)
        # Only element [0] of the match result is used, so a phrase that
        # occurs several times is merged at one position only.
        span = utils.find_sublist(phrase_tokens, tokens)[0]
        start, stop = span[0], span[1] + 1  # span end is inclusive
        tokens = tokens[:start] + [phrase] + tokens[stop:]
        tagged = tagged[:start] + [(phrase, 'NP')] + tagged[stop:]
    return tokens, tagged
Ejemplo n.º 3
0
 def get_test_entities(self):
     """Collect occurrences of known named entities in the test data.

     Every entity token sequence in ``self.known_entities`` (a mapping from
     entity type to a list of token sequences) is searched for in each test
     sentence of ``self.test_data``.

     Returns:
         dict mapping entity type to a list of match records. Each record is
         a dict with ``'string'`` (the matched tokens joined by spaces) and
         ``'indices'`` (``(indices[begin], indices[end])`` taken from the
         sentence's index list). The keys 'PER', 'LOC', 'ORG' and 'MISC' are
         always present. Matches for any other type found in
         ``known_entities`` are added under that type instead of raising
         ``KeyError``.
     """
     print('\nTEST')
     test_entities = {'PER': [], 'LOC': [], 'ORG': [], 'MISC': []}
     # Each test item is a (tokens, pos_tags, indices) triple; pos_tags is unused.
     for sent, _pos_tags, indices in self.test_data:
         for key, names in self.known_entities.items():
             for ne in names:
                 # Spans are treated as inclusive: sent[begin:end + 1].
                 for begin, end in find_sublist(ne, sent):
                     tag = {
                         'string': ' '.join(sent[begin:end + 1]),
                         'indices': (indices[begin], indices[end]),
                     }
                     # setdefault: tolerate entity types beyond the four defaults.
                     test_entities.setdefault(key, []).append(tag)
     return test_entities
Ejemplo n.º 4
0
def _load_flickr30k(dataroot, img_id2idx, bbox, pos_boxes):
    """Load Flickr30k Entities grounding entries.

    For every image, the boxes in
    ``<dataroot>/Flickr30kEntities/Annotations/<image_id>.xml`` are grouped by
    entity id. Each annotated sentence in
    ``<dataroot>/Flickr30kEntities/Sentences/<image_id>.txt`` that mentions at
    least one boxed entity becomes one entry built by ``_create_flickr_entry``.

    Args:
        dataroot: root path of the dataset.
        img_id2idx: dict ``{img_id -> idx}``. ``idx`` selects this image's row
            in ``pos_boxes`` and is stored in each entry.
        bbox: sliceable collection of region boxes for all images.
        pos_boxes: per-image ``(start, end)`` offsets into ``bbox``, indexed
            by ``idx``.

    Returns:
        list of entries, one per sentence with at least one boxed entity.

    Raises:
        AssertionError: on non-positive ``xmin``/``ymin`` or entity ids, or
            when ``utils.find_sublist`` reports a negative index for a phrase
            (presumably "not found" -- TODO confirm).
    """
    pattern_phrase = r'\[(.*?)\]'
    pattern_no = r'\/EN\#(\d+)'

    missing_entity_count = dict()
    multibox_entity_count = 0

    entries = []
    for image_id, idx in img_id2idx.items():

        phrase_file = os.path.join(
            dataroot, 'Flickr30kEntities/Sentences/%d.txt' % image_id)
        anno_file = os.path.join(
            dataroot, 'Flickr30kEntities/Annotations/%d.xml' % image_id)

        with open(phrase_file, 'r', encoding='utf-8') as f:
            sents = [x.strip() for x in f]

        # Parse annotation: collect every ground-truth box per entity id.
        root = parse(anno_file).getroot()
        obj_elems = root.findall('./object')
        pos_box = pos_boxes[idx]
        # Region boxes belonging to this image.
        bboxes = bbox[pos_box[0]:pos_box[1]]
        target_bboxes = {}

        for elem in obj_elems:
            bndbox = elem.find('bndbox')
            # Objects without a (non-empty) <bndbox> carry no box.
            if bndbox is None or len(bndbox) == 0:
                continue
            left = int(elem.findtext('./bndbox/xmin'))
            top = int(elem.findtext('./bndbox/ymin'))
            right = int(elem.findtext('./bndbox/xmax'))
            bottom = int(elem.findtext('./bndbox/ymax'))
            assert 0 < left and 0 < top

            for name in elem.findall('name'):
                entity_id = int(name.text)
                assert 0 < entity_id
                if entity_id in target_bboxes:
                    # Entity already has a box: count the extra one.
                    multibox_entity_count += 1
                target_bboxes.setdefault(entity_id, []).append(
                    [left, top, right, bottom])

        # Parse sentences: one entry per sentence with at least one boxed entity.
        for sent in sents:
            sentence = utils.remove_annotations(sent)
            words = sentence.split(' ')
            entity_indices = []
            target_indices = []
            entity_ids = []
            entity_types = []

            for entity in re.findall(pattern_phrase, sent):
                # Expected form: '/EN#<id>/<type>[/<type>...] <phrase words>'.
                info, phrase = entity.split(' ', 1)
                entity_id = int(re.findall(pattern_no, info)[0])
                entity_type = info.split('/')[2:]

                entity_idx = utils.find_sublist(words, phrase.split(' '))
                assert 0 <= entity_idx, (
                    'phrase %r not found in sentence %r' % (phrase, sentence))

                if entity_id not in target_bboxes:
                    # The id comes from ``\d+`` so it is always >= 0: every
                    # unboxed entity is tallied under its first type.
                    missing_entity_count[entity_type[0]] = (
                        missing_entity_count.get(entity_type[0], 0) + 1)
                    continue

                assert 0 < entity_id

                entity_ids.append(entity_id)
                entity_types.append(entity_type)

                target_idx = utils.get_match_index(target_bboxes[entity_id],
                                                   bboxes)
                entity_indices.append(entity_idx)
                target_indices.append(target_idx)

            if not entity_ids:
                continue

            entries.append(
                _create_flickr_entry(idx, sentence, entity_indices,
                                     target_indices, entity_ids, entity_types))

    if missing_entity_count:
        print('missing_entity_count=')
        print(missing_entity_count)
        print('multibox_entity_count=%d' % multibox_entity_count)

    return entries
Ejemplo n.º 5
0
def _load_kairos(dataset,
                 img_id2idx,
                 bbox,
                 pos_boxes,
                 topic_doc_json,
                 topic=None):
    """Load KAIROS entity-annotated sentences as grounding entries.

    For every image in ``img_id2idx`` and every document id listed under
    ``topic_doc_json[topic]``, the annotated sentences in
    ``data/<dataset>/json_output/ent_sents/<doc_id>.txt`` are parsed. Each
    sentence with at least one locatable entity becomes one entry built by
    ``_create_kairos_entry``.

    Args:
        dataset: dataset directory name under ``data/``.
        img_id2idx: dict ``{img_id -> idx}``. ``idx`` is stored in each entry;
            the image id itself is not used.
        bbox: unused. Bounding-box matching is disabled for this dataset; the
            parameter is kept for signature compatibility.
        pos_boxes: unused, for the same reason as ``bbox``.
        topic_doc_json: mapping from topic name to an iterable of document ids.
        topic: key into ``topic_doc_json``; defaults to ``dataset``.

    Returns:
        list of entries.

    Raises:
        ValueError: if an annotation does not carry a numeric ``/EN#<id>``.
        AssertionError: if an entity id is 0.
        Exception: if ``_create_kairos_entry`` fails (original error chained).
    """
    pattern_phrase = r'\[\/EN\#(.*?)\]'
    pattern_no = r'\/EN\#(\d+)'

    if topic is None:
        topic = dataset

    # The same documents are paired with every image, so each file is read
    # from disk once and its (immutable) lines are reused.
    sents_by_doc = {}
    entries = []

    for idx in img_id2idx.values():
        for phrase_id in topic_doc_json[topic]:
            sents = sents_by_doc.get(phrase_id)
            if sents is None:
                phrase_file = f'data/{dataset}/json_output/ent_sents/{phrase_id}.txt'
                with open(phrase_file, 'r', encoding='utf-8') as f:
                    sents = [x.strip() for x in f]
                sents_by_doc[phrase_id] = sents

            for sent_id, sent in enumerate(sents):
                sentence = utils.remove_annotations(sent)
                entity_indices = []
                entity_ids = []
                entity_types = []

                for entity in re.findall(pattern_phrase, sent):
                    # pattern_phrase strips the '/EN#' prefix; restore it so
                    # ``info`` reads '/EN#<id>/<type>...'.
                    info, phrase = ("/EN#" + entity).split(' ', 1)
                    id_match = re.findall(pattern_no, info)
                    if not id_match:
                        raise ValueError(
                            f"entry creation failed: no entity id in "
                            f"info = {info}, sentence = {sentence}")
                    entity_id = int(id_match[0])
                    entity_type = info.split('/')[2:]

                    entity_idx = utils.find_sublist(sentence.split(' '),
                                                    phrase.split(' '))
                    # Skip phrases that cannot be located in the cleaned
                    # sentence. This is an explicit check rather than an
                    # ``assert`` so it still applies under ``python -O``.
                    if entity_idx is None or entity_idx < 0:
                        continue

                    assert 0 < entity_id

                    entity_ids.append(entity_id)
                    entity_types.append(entity_type)
                    entity_indices.append(entity_idx)

                if not entity_ids:
                    continue
                try:
                    entry = _create_kairos_entry(idx,
                                                 f"{phrase_id}-s{sent_id}",
                                                 sentence, entity_indices,
                                                 entity_ids, entity_types)
                except Exception as err:
                    print(idx, sent_id, sentence, sent)
                    raise Exception("entry creation failed") from err

                entries.append(entry)

    return entries