Code Example #1
    def create_batches(self, batch_size: int, bptt_size: int):
        r"""A general routine to create batches of specified batch size and BPTT length.

        :param batch_size: The number of examples in one batch.
        :param bptt_size: The length for truncated-backprop, i.e. the maximum length of sentences in one batch.
        """
        self.batches = {}
        self.ntokens = {}
        for split, raw_dataset in self.data.items():
            ntokens = 0
            # sort the data by document length
            parts = sorted(raw_dataset, key=len)
            num_batches = utils.ceil_div(len(parts), batch_size)
            all_batches = []
            for batch_idx in utils.progress(num_batches,
                                            desc="Creating batches",
                                            ascii=True,
                                            ncols=80):
                part = parts[(batch_idx * batch_size):((batch_idx + 1) *
                                                       batch_size)]
                init_batch, batches = self.create_one_batch(part, bptt_size)
                ntokens += sum(batch.ntokens for batch in batches)
                all_batches.append((init_batch, batches))
            self.batches[split] = all_batches
            self.ntokens[split] = ntokens

        unk_probs = self.unk_probs
        if unk_probs is not None:
            total_w2i = self.total_w2i
            for split, dataset in self.batches.items():
                dataset = utils.progress(
                    dataset,
                    ncols=80,
                    desc=f"Adding unk vocab for {split} set",
                    ascii=True)
                for _, batches in dataset:
                    for batch in batches:
                        batch.add_unk_probs(unk_probs, total_w2i)
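
A minimal standalone sketch of the batching pattern used here (not project code; the names below are illustrative): sort sequences by length so similar lengths share a batch, then slice the sorted list into ceil(n / batch_size) groups, which is the role `utils.ceil_div` plays above.

from math import ceil

def make_batches(sequences, batch_size):
    parts = sorted(sequences, key=len)            # keep similar lengths together
    num_batches = ceil(len(parts) / batch_size)   # plays the role of utils.ceil_div
    return [parts[i * batch_size:(i + 1) * batch_size] for i in range(num_batches)]

batches = make_batches([[1], [2, 3], [4, 5, 6], [7]], batch_size=2)
assert [len(b) for b in batches] == [2, 2]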
Code Example #2
    def average_copies(example,
                       start_symbol,
                       max_length=200,
                       n_tries=10,
                       progress=False,
                       show_samples=False,
                       **kwargs):
        complete_copies = utils.SimpleAverage()
        incomplete_copies = utils.SimpleAverage()

        if show_samples:
            sample_kwargs = dict(print_info=True,
                                 color_outputs=True,
                                 color_incomplete=False,
                                 **kwargs)
        else:
            sample_kwargs = dict(print_info=False,
                                 color_outputs=False,
                                 **kwargs)

        for _ in utils.progress(n_tries, verbose=progress):
            output: SampledOutput = model.sampling_decode(
                dataset.vocab,
                example,
                begin_symbol=start_symbol,
                end_symbol=end_symbol,
                max_length=max_length,
                **sample_kwargs)
            complete_copies.add(output.complete_copies)
            incomplete_copies.add(output.incomplete_copies)
            if show_samples:
                print(' '.join(output.sentence))
        if isinstance(example, LRLMExample):
            n_gold_rels = sum(
                int(span.end < max_length) for span in example.spans)
            print(
                f"Complete / Gold entities: {complete_copies.value()} / {n_gold_rels}"
            )
        else:
            print(
                f"Complete / Incomplete entities: {complete_copies.value()} / {incomplete_copies.value()}"
            )
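
`utils.SimpleAverage` is project code; a rough standalone equivalent of the running-average accumulator assumed here (the real implementation may differ):

class SimpleAverage:
    def __init__(self):
        self.total = 0.0
        self.count = 0

    def add(self, value):
        self.total += value
        self.count += 1

    def value(self):
        return self.total / self.count if self.count > 0 else 0.0

avg = SimpleAverage()
for x in [3, 4, 5]:
    avg.add(x)
assert avg.value() == 4.0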
Code Example #3
def read_data(path: Path) -> Dict[str, List[RawExampleWikiID]]:
    bad_examples: List[Tuple[str, int, str]] = []
    data = {}
    for split in ['train', 'valid', 'test']:
        with (path / f'{split}.pkl').open('rb') as f:
            # relation tuple: (span, rel_type_desc, name, canonical_name)
            with utils.work_in_progress(f"Loading {split} set"):
                dump: List[RawDump] = pickle.load(f)

            examples = []
            for idx, (sent, rels) in enumerate(
                    utils.progress(dump, desc='Reading data')):
                # map (rel_typ, canonical) to list of aliases, since lists aren't hashable
                rel_to_alias: Dict[Tuple[str, str], List[str]] = \
                    {(rel[0][0], obj_id): alias for obj_id, _, _, rel, _, alias in rels}

                # sort it so the order is consistent
                relations: List[RelationWikiID] = sorted([
                    RelationWikiID(WikidataID(rel_id), WikidataID(obj_id),
                                   obj_alias)
                    for (rel_id, obj_id), obj_alias in rel_to_alias.items()
                ])
                rel_to_id: Dict[Tuple[str, str], int] = {
                    (rel_id, obj_id): idx
                    for idx, (rel_id, obj_id,
                              obj_alias) in enumerate(relations)
                }
                # dedup to remove duplicate (-1, -1)
                mentions: List[EntityMention] = list(set(
                    EntityMention(span, surface, rel_to_id[(rel_info[0][0], obj_id)])
                    for obj_id, head_id, span, rel_info, surface, _ in rels))
                try:
                    # must exist: the head id of the @TITLE@ relation is the topic's WikidataID
                    topic_id = next(
                        head_id
                        for _, head_id, _, rel_info, surface, alias in rels
                        if rel_info[0][0] == "@TITLE@")
                except StopIteration:
                    bad_examples.append((split, idx, ' '.join(sent)[:100]))
                    continue

                converted_relations = []
                for r in relations:
                    converted_relations.append(
                        RelationWikiID(
                            TOPIC_ITSELF if r.rel_typ == "@TITLE@" else
                            r.rel_typ, r.obj, r.obj_alias))

                example = RawExampleWikiID(WikidataID(topic_id), sent,
                                           converted_relations, mentions)
                examples.append(example)
            data[split] = examples

    if len(bad_examples) > 0:
        Logging.warn(f"{len(bad_examples)} bad examples:\n"
                     f"{pprint.pformat(bad_examples)}")
    else:
        Logging.verbose("All examples are good")

    return data
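
The core bookkeeping here is: deduplicate (relation, object) pairs, sort them for a stable order, and index each mention by its position in that sorted list. A toy standalone illustration with made-up identifiers:

pairs = [("P19", "Q90"), ("P27", "Q142"), ("P19", "Q90")]
relations = sorted(set(pairs))                              # dedup + stable order
rel_to_id = {pair: idx for idx, pair in enumerate(relations)}
assert relations == [("P19", "Q90"), ("P27", "Q142")]
assert rel_to_id[("P27", "Q142")] == 1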
Code Example #4
    def _extra_init(self, loaded_batches: bool):
        self.rel_vocab = Vocab.from_dict(self._path / 'rel_names.pkl',
                                         mode='i2w')
        self.vocab: Dict[str, Vocab] = {
            "word": self.word_vocab,
            "rel": self.rel_vocab
        }

        self.max_unkrel = max(
            (-rel_typ - 3 for rel_typ in self.rel_vocab.i2w if rel_typ < -3),
            default=0)

        if self._use_fasttext:

            def _alias_path(name):
                path = Path(self._fasttext_model_path)
                return path.parent / (path.name + f'.{name}')

            # gather all entity aliases and compute fastText embeddings
            alias_dict_path = _alias_path('alias_dict.pkl')
            if alias_dict_path.exists():
                alias_dict: Dict[str, int] = loadpkl(alias_dict_path)
                loaded = True
            else:
                alias_dict = defaultdict(lambda: len(alias_dict))
                loaded = False
            if not loaded_batches:
                for dataset in self.data.values():
                    for example in dataset:
                        for idx, rel in enumerate(example.relations):  # type: ignore
                            example.relations[idx] = rel._replace(  # type: ignore
                                obj_alias=[alias_dict[s] for s in rel.obj_alias])
            if not alias_dict_path.exists():
                alias_dict = dict(alias_dict)
                savepkl(alias_dict, alias_dict_path)

            alias_vectors_path = _alias_path('alias_vectors.pt')
            if not alias_vectors_path.exists() or not loaded:
                import fastText
                ft_model = fastText.load_model(self._fasttext_model_path)
                alias_vectors = []
                alias_list = utils.reverse_map(alias_dict)
                for alias in utils.progress(alias_list,
                                            desc="Building fastText vectors",
                                            ascii=True,
                                            ncols=80):
                    vectors = [
                        ft_model.get_word_vector(w) for w in alias.split()
                    ]
                    vectors = np.sum(vectors, axis=0).tolist()
                    alias_vectors.append(vectors)
                alias_vectors = torch.tensor(alias_vectors)
                torch.save(alias_vectors, alias_vectors_path)

        if not loaded_batches and (self._exclude_entity_disamb
                                   or self._exclude_alias_disamb):
            # no need to do this if batches are loaded
            if self._exclude_entity_disamb:
                # gather training set stats
                self.entity_count_per_type = self.gather_entity_stats(
                    self.data['train'])

            for dataset in self.data.values():
                for idx in range(len(dataset)):
                    dataset[idx] = self.remove_ambiguity(
                        dataset[idx], self._exclude_entity_disamb,
                        self._exclude_alias_disamb)
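
The line `alias_dict = defaultdict(lambda: len(alias_dict))` above is a self-indexing interning idiom: looking up an unseen key assigns it the next integer id. A tiny standalone demonstration:

from collections import defaultdict

alias_dict = defaultdict(lambda: len(alias_dict))
ids = [alias_dict[a] for a in ["Paris", "City of Light", "Paris"]]
assert ids == [0, 1, 0]
assert dict(alias_dict) == {"Paris": 0, "City of Light": 1}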
Code Example #5
File: process_wikifacts.py  Project: jlacomis/lrlm
def load_wikifacts():
    """ Load the WikiFacts dataset """
    id_name: Dict[FreebaseID, str] = {}
    id_summary: Dict[FreebaseID, str] = {}
    id_relations: Dict[FreebaseID, List[Relation]] = {}
    relation_types: Set[FreebaseID] = set()

    def add_or_check(fid: FreebaseID, name: str):
        name = name.lower()
        if fid in id_name:
            assert name == id_name[fid]
        else:
            id_name[fid] = name

    tar = tarfile.open(DATASET_PATH)

    id_fb_raw: Dict[FreebaseID, str] = {}
    id_en_raw: Dict[FreebaseID, str] = {}

    # Annotated Wikipedia summary (.sm)
    for summary_file in utils.progress(tar.getmembers(), desc='Reading summary (.sm)'):
        if not summary_file.name.endswith('.sm'):
            continue
        freebase_id = extract_id(summary_file)
        with tar.extractfile(summary_file) as f:
            id_summary[freebase_id] = f.read().decode('utf-8')

    # Cache Freebase Topic files (.fb, .en)
    for file in utils.progress(tar.getmembers(), desc='Caching topic files'):
        freebase_id = extract_id(file)
        if file.name.endswith('.fb'):
            with tar.extractfile(file) as f:
                id_fb_raw[freebase_id] = f.read().decode('utf-8')
        elif file.name.endswith('en'):
            with tar.extractfile(file) as f:
                id_en_raw[freebase_id] = f.read().decode('utf-8')

    # Freebase Topic (.fb, .en)
    for freebase_id in utils.progress(id_fb_raw, desc='Extracting relations'):
        relations = []
        f_fb = id_fb_raw[freebase_id].split('\n')
        f_en = id_en_raw[freebase_id].split('\n')
        for rel_line, name_line in zip(f_fb, f_en):
            parts = rel_line.split()
            if len(parts) == 0:  # empty line
                continue
            rel_name = name_line.strip().split(' ')  # do not split on non-breaking or unicode spaces
            assert len(parts) == len(rel_name)
            if len(parts) == 5:  # composite value type
                if parts[0].startswith('['):  # subject is CVT
                    continue
                parts = [parts[0], parts[3], parts[4].rstrip(']')]
                if parts[0] == parts[-1]:  # don't include a relation with itself
                    continue
                # remove the final ']', but do not use `.rstrip` because the name could contain ']'
                rel_name = [rel_name[0], rel_name[3], rel_name[4][:-1]]
            elif len(parts) == 3:  # simple relation
                pass
            else:
                raise ValueError  # malformed data
            rel = [remove_prefix(r) for r in parts]
            r_sub, r_rel, r_obj = rel
            if r_sub != freebase_id:  # only keep relations whose subject matches topic
                continue
            add_or_check(r_sub, rel_name[0])
            add_or_check(r_obj, rel_name[2])
            relations.append((r_rel, r_obj))
            relation_types.add(r_rel)
        id_relations[freebase_id] = relations

    # # Freebase to Wikidata mappings
    # freebase_wikidata_map = {}
    # with utils.progress(open(FREEBASE_MAPPING_PATH), desc='Reading fb2w') as f:
    #     for line in f:
    #         if line.startswith('#') or line.strip() == '':
    #             continue
    #         parts = line.split()
    #         freebase_id = parts[0].split('/')[-1][:-1]
    #         wikidata_id = parts[2].split('/')[-1][:-1]
    #         freebase_wikidata_map[freebase_id] = wikidata_id
    #
    # # Check coverage
    # print(f"id_name coverage in fb2w: "
    #       f"{sum(int(fid in freebase_wikidata_map) for fid in id_name)}/{len(id_name)}")
    # print(f"id_summary coverage in fb2w: "
    #       f"{sum(int(fid in freebase_wikidata_map) for fid in id_summary)}/{len(id_summary)}")

    return id_name, id_summary, id_relations, relation_types
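
The `add_or_check` helper enforces that every occurrence of a Freebase id carries the same name (compared case-insensitively). A standalone illustration with made-up data:

id_name = {}

def add_or_check(fid, name):
    name = name.lower()
    if fid in id_name:
        assert name == id_name[fid], f"conflicting names for {fid}"
    else:
        id_name[fid] = name

add_or_check("m.0f8l9c", "France")
add_or_check("m.0f8l9c", "FRANCE")   # lower-cased before comparison, so this passes
assert id_name["m.0f8l9c"] == "france"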
Code Example #6
File: process_wikifacts.py  Project: jlacomis/lrlm
def main():
    replace_canonical = any(arg.startswith('--replace') for arg in sys.argv)
    if not replace_canonical:
        # global SAVE_DIR
        # SAVE_DIR = Path('./data/wikifacts_orig/')
        pass
    else:
        print("Arguments: Will replace the canonical forms.")
    print(f"Output directory: {SAVE_DIR}")

    skip_embeddings = any(arg.startswith('--skip') for arg in sys.argv)
    if skip_embeddings:
        print("Arguments: Will skip embedding generation.")

    id_name, id_summary, id_relations, relation_types = load_wikifacts()

    entity_index, relation_index, entity_vec, relation_vec = load_openke_embeddings()

    # Check OpenKE coverage
    print(f"Entity coverage in OpenKE: "
          f"{sum(int(fid in entity_index) for fid in id_name)}/{len(id_name)}")
    print(f"Relation coverage in OpenKE: "
          f"{sum(int(fid in relation_index) for fid in relation_types)}/{len(relation_types)}")

    """ Match entity positions and generate pickled dataset """

    # Remap entities and rel-types and store the mapping
    entity_map: Dict[FreebaseID, int] = {}
    relation_map: Dict[FreebaseID, int] = {}
    mapped_entity_vecs: List[np.ndarray] = []
    mapped_relation_vecs: List[np.ndarray] = []

    # noinspection PyShadowingNames
    def get_relation_id(rel: FreebaseID) -> int:
        rel_id = relation_map.get(rel, None)
        if rel_id is None:
            rel_id = relation_map[rel] = len(relation_map)
            mapped_relation_vecs.append(extract_vector(relation_vec, relation_index[rel]))
        return rel_id

    # noinspection PyShadowingNames
    def get_entity_id(entity: FreebaseID) -> int:
        ent_id = entity_map.get(entity, None)
        if ent_id is None:
            if entity not in entity_index:  # not all covered
                return UNK_ENTITY
            ent_id = entity_map[entity] = len(entity_map)
            mapped_entity_vecs.append(extract_vector(entity_vec, entity_index[entity]))
        return ent_id

    # Create the dataset
    dataset: List[Example] = []  # mapping of topic ID to data example
    dataset_matched_spans: List[List[MatchedSpan]] = []  # mapping of topic ID to matched spans

    # noinspection PyShadowingNames
    def find_relation(rels: List[Tuple[FreebaseID, FreebaseID]], obj: FreebaseID) -> List[FreebaseID]:
        # collect every relation type whose object matches (there may be more than one)
        matched_rels = []
        for r_rel, r_obj in rels:
            if r_obj == obj:
                matched_rels.append(r_rel)
        return matched_rels

    # noinspection PyShadowingNames
    def match_positions(tokens: List[str], name: List[str]) -> List[int]:
        matched = [False] * len(name)
        positions = []
        for idx, token in enumerate(tokens):
            for match_idx, match_token in enumerate(name):
                if matched[match_idx] or match_token != token:
                    continue
                positions.append(match_idx)
                matched[match_idx] = True
                break
            else:
                return []
        return positions

    tokenizer = sacremoses.MosesTokenizer(lang='en')
    position_stats = defaultdict(int)
    for freebase_id in utils.progress(id_summary, desc='Creating dataset'):
        topic_id = get_entity_id(freebase_id)
        summary = id_summary[freebase_id].strip().split(' ')
        raw_relations = id_relations[freebase_id]
        relations = defaultdict(lambda: len(relations))
        rel_obj_names: Dict[int, str] = {}
        for r_rel, r_obj in raw_relations:
            rel_obj_names[relations[(get_relation_id(r_rel), get_entity_id(r_obj))]] = id_name[r_obj]
        topic_name = summary[0][2:summary[0].index('/')] if '/' in summary[0] else "<unknown topic name>"
        rel_obj_names[relations[(TOPIC_ITSELF, topic_id)]] = topic_name  # topic_itself
        sentence, rel_ids, copy_pos, surface_indices = [], [], [], []
        matched_spans: List[MatchedSpan] = []

        def add_words(s: str):
            tokens = tokenizer.tokenize(s, escape=False)
            sentence.extend(tokens)
            rel_ids.extend([NAF] * len(tokens))
            copy_pos.extend([-1] * len(tokens))
            surface_indices.extend([-1] * len(tokens))

        for word in summary:
            if '@@' in word:
                start_pos = word.find('@@')
                end_pos = word.find('@@', start_pos + 2)

                if start_pos > 0:
                    add_words(word[:start_pos])  # leading stuff

                entity_desc = word[(start_pos + 2):end_pos].split('/')  # there could be punctuation following '@@'
                assert len(entity_desc) >= 4  # entity name could contain '/'
                trailing = word[(end_pos + 2):]  # trailing stuff
                entity_name = '/'.join(entity_desc[:-3]).split('_')
                r_obj = FreebaseID(entity_desc[-1])
                obj_id = get_entity_id(r_obj)
                if entity_desc[-3] == 'f':
                    if r_obj == freebase_id:  # topic_itself
                        rel_id = TOPIC_ITSELF
                        matched_rels = [TOPIC_ITSELF]
                    else:
                        matched_rels = [get_relation_id(r_rel) for r_rel in find_relation(raw_relations, r_obj)]
                        # the relation might not exist
                        rel_id = None if len(matched_rels) == 0 else matched_rels[0]
                else:
                    rel_id = ANCHOR
                    matched_rels = [ANCHOR]  # include anchors anyway and filter them out if we don't use them
                    # matched_rels = []  # don't include anchors because we're only using them for our model

                if rel_id is None:  # unknown entities are allowed now; previously this also checked obj_id is None
                    # no embedding for word, convert to normal token
                    add_words(' '.join(entity_name) + trailing)
                else:
                    position_stats['all'] += 1

                    if replace_canonical and r_obj in id_name:  # basically, everything except anchors
                        # just replace the entity_name, position will be in order
                        canonical_name = id_name[r_obj].split('_')
                        if canonical_name != entity_name:
                            position_stats['not_canonical'] += 1
                        entity_name = canonical_name

                    if entity_desc[-3] == 'f' and r_obj != freebase_id:
                        positions = match_positions(entity_name, id_name[r_obj].split('_'))
                        if len(positions) == 0:  # cannot match, resort to replacing
                            entity_name = id_name[r_obj].split('_')
                            positions = list(range(len(entity_name)))
                            position_stats['order'] += 1
                        else:
                            position_stats['match'] += 1
                            if positions == list(range(len(entity_name))):
                                position_stats['order'] += 1
                                position_stats['match_order'] += 1
                            elif positions == list(range(len(positions))):
                                position_stats['match_prefix'] += 1
                            elif positions == list(range(positions[0], positions[-1] + 1)):
                                position_stats['match_sub'] += 1
                    else:
                        # we don't replace canonical names for anchors (because we don't have them)
                        # anyway, these can't be used by our model
                        if entity_desc[-3] == 'a':
                            position_stats['anchor'] += 1
                            # add the anchor relation
                            rel_obj_names[relations[(ANCHOR, obj_id)]] = ' '.join(entity_name)
                        positions = list(range(len(entity_name)))  # these may not be in `id_name`, use as is
                        position_stats['order'] += 1

                    assert (rel_id, obj_id) in relations
                    rel_idx = relations[(rel_id, obj_id)]  # must exist
                    for rel_typ in matched_rels:
                        matched_spans.append((len(sentence), len(sentence) + len(entity_name), rel_typ, rel_idx, 0))
                    sentence.extend(entity_name)
                    rel_ids.extend([rel_idx] * len(entity_name))
                    copy_pos.extend(positions)
                    surface_indices.extend([0] * len(entity_name))

                    add_words(trailing)
            else:
                add_words(word)

        # we assume everything's canonical, so just add a pseudo canonical form
        rel_rev_map = {idx: rel for rel, idx in relations.items()}
        rel_list = []  # just in case `list(relations)` doesn't do as we expect
        for idx in range(len(rel_rev_map)):
            rel_id, obj_id = rel_rev_map[idx]
            rel_list.append((rel_id, obj_id, [rel_obj_names[idx].replace('_', ' ')]))
        example = (sentence, topic_id, rel_list, rel_ids, copy_pos, surface_indices)
        dataset.append(example)
        dataset_matched_spans.append(matched_spans)
    print(f"Position stats: {position_stats}")

    if replace_canonical:
        assert position_stats['all'] == position_stats['order']

    # Save them
    directory: Path = SAVE_DATASET_PATH('train').parent
    if not directory.exists():
        directory.mkdir(parents=True)

    # noinspection PyShadowingNames
    def split_dataset(dataset: List[T], splits: List[Tuple[int, int]]) -> List[List[T]]:
        dataset_size = len(dataset)
        dataset_splits = []
        for l, r in splits:
            start = int(dataset_size * (l / 100))
            end = int(dataset_size * (r / 100))
            dataset_splits.append(dataset[start:end])
        return dataset_splits

    splits = [(0, 80), (80, 90), (90, 100)]
    dataset_splits = split_dataset(dataset, splits)
    matched_spans_splits = split_dataset(dataset_matched_spans, splits)

    print(f"{len(dataset)} examples")
    for split, data, spans in zip(['train', 'valid', 'test'], dataset_splits, matched_spans_splits):
        path = SAVE_DATASET_PATH(split)
        with path.open('wb') as f:
            pickle.dump(data, f)
        print(f"Dataset split '{split}' saved to {path}, {len(data)} examples")

        path = SAVE_MATCHED_SPANS_PATH(split)
        with path.open('wb') as f:
            pickle.dump(spans, f)
        print(f"Matched spans split '{split}' saved to {path}")

    # save relation type names for use during generation
    rel_names: Dict[int, str] = {NAF: 'Not-A-Fact', ANCHOR: 'ANCHOR', TOPIC_ITSELF: 'TITLE'}
    for r_rel, rel_id in relation_map.items():
        rel_names[rel_id] = r_rel
    with (SAVE_DIR / 'rel_names.pkl').open('wb') as f:
        pickle.dump(rel_names, f)
    print("Relation names saved.")

    if not skip_embeddings:
        with utils.work_in_progress("Saving entity vecs"):
            stacked_entity_vecs = torch.from_numpy(np.stack(mapped_entity_vecs))
            torch.save(stacked_entity_vecs, SAVE_ENTITY_VEC_PATH)
        print(f"Entity vecs saved to {SAVE_ENTITY_VEC_PATH}, {len(stacked_entity_vecs)} vectors in total")

        with utils.work_in_progress("Saving relation vecs"):
            stacked_relation_vecs = torch.from_numpy(np.stack(mapped_relation_vecs))
            torch.save(stacked_relation_vecs, SAVE_RELATION_VEC_PATH)
        print(f"Relation vecs saved to {SAVE_RELATION_VEC_PATH}, {len(stacked_relation_vecs)} vectors in total")
    else:
        print("Embedding updates skipped.")

    print("Processing done.")