Example #1
def load_openke_embeddings():
    """Load the OpenKE pretrained embeddings and the Freebase-ID-to-index mappings."""

    entity_index: Dict[FreebaseID, int] = {}
    relation_index: Dict[FreebaseID, int] = {}

    # Entity to OpenKE ID mapping
    with utils.work_in_progress("Reading entity2id"), ENTITY_MAPPING_PATH.open() as f:
        entity_count = int(f.readline())
        for line in f:
            parts = line.split()
            freebase_id = FreebaseID(parts[0])
            entity_index[freebase_id] = int(parts[1])
        assert len(entity_index) == entity_count

    # Relation to OpenKE ID mapping
    with utils.work_in_progress("Reading relation2id"), RELATION_MAPPING_PATH.open() as f:
        relation_count = int(f.readline())
        for line in f:
            parts = line.split()
            freebase_id = FreebaseID(parts[0])
            relation_index[freebase_id] = int(parts[1])
        assert len(relation_index) == relation_count

    # Load binary vectors
    entity_vec = np.memmap(ENTITY_VEC_PATH, dtype=np.float32, mode='r')
    relation_vec = np.memmap(RELATION_VEC_PATH, dtype=np.float32, mode='r')

    return entity_index, relation_index, entity_vec, relation_vec
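Both `.vec` files are mapped as flat one-dimensional float32 buffers, so an individual embedding still has to be sliced out by row. Below is a minimal sketch of such a lookup; a helper named `extract_vector` is used in Example #5, but its signature and the embedding dimensionality shown here are assumptions.

import numpy as np

EMBED_DIM = 100  # assumption: dimensionality the OpenKE model was trained with

def extract_vector(flat_vec: np.ndarray, index: int, dim: int = EMBED_DIM) -> np.ndarray:
    """Slice embedding row `index` out of a flat memmapped vector file."""
    return np.asarray(flat_vec[index * dim:(index + 1) * dim])

# Usage (hypothetical Freebase ID):
#   entity_index, relation_index, entity_vec, relation_vec = load_openke_embeddings()
#   vec = extract_vector(entity_vec, entity_index[FreebaseID('/m/02mjmr')])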
Example #2
def main():
    Logging.verbosity_level = Logging.VERBOSE

    Logging.warn("This program requires lots of memory (preferably >= 30GB).")

    if not SAVE_DIR.exists():
        SAVE_DIR.mkdir(parents=True)

    # Read the Wikidata ID of each article and collect the set of topic IDs
    topic_ids: Set[WikidataID] = set()
    split_title_id: Dict[str, List[Tuple[str, WikidataID]]] = {}
    for split in ['train', 'valid', 'test']:
        with utils.work_in_progress(f"Loading {split} set titles"), \
             open(TOPIC_JSON_PATH(split=split)) as f:
            j = json.load(f)
        split_title_id[split] = [(article['title'], WikidataID(article['id']))
                                 for article in j]
        topic_ids.update([wid for _, wid in split_title_id[split]])
        del j

    with utils.work_in_progress("Loading Wikidata ID mapping"):
        id2rel = load_id2str(WIKIDATA_DUMP_DIR / 'properties.txt')

    # Match the relations
    matched_dataset = read_data(ALIGNED_DATA_DIR)

    # Gather entities & relation vectors
    found_entities = set()
    found_rels = set()
    for split in matched_dataset:
        for example in matched_dataset[split]:
            found_entities.add(example.topic_id)
            for rel in example.relations:
                found_entities.add(rel.obj_id)
                found_rels.add(rel.rel_typ)
    found_entities -= {UNK_ENTITY}
    found_rels -= {NAF, ANCHOR, TOPIC_ITSELF}
    with utils.work_in_progress("Building rel vecs"):
        rel_map = load_relations(found_rels)
        rel_map.update({NAF: -1, ANCHOR: -2, TOPIC_ITSELF: -3})
        unk_rels = found_rels.difference(rel_map)
        # NOTE: `unk_rels` is a set with nondeterministic iteration order, so sort it to keep IDs consistent across runs
        for idx, rel in enumerate(sorted(unk_rels)):
            rel_map[rel] = -idx - 4  # starting from -4, going towards -inf
    with utils.work_in_progress("Building entity vecs"):
        entity_map = load_entities(found_entities)
        entity_map.update({UNK_ENTITY: -1})
        print(
            f"Topic ID coverage: {len(topic_ids.intersection(entity_map))}/{len(topic_ids)}"
        )

    # save relation type names for use during generation
    id_to_rel_name = dict(id2rel)
    id_to_rel_name.update({
        NAF: 'Not-A-Fact',
        ANCHOR: 'ANCHOR',
        TOPIC_ITSELF: 'TITLE'
    })
    rel_names: Dict[int, str] = {}
    for r_rel, rel_id in rel_map.items():
        rel_names[rel_id] = id_to_rel_name[r_rel]
    with (SAVE_DIR / 'rel_names.pkl').open('wb') as f:
        pickle.dump(rel_names, f)
        print(f"Relation names saved to {(SAVE_DIR / 'rel_names.pkl')}")

    # Convert into numbers to create the final dataset
    for split in matched_dataset:
        with utils.work_in_progress(f"Converting {split} set"):
            dataset, matched_spans = numericalize_rel(matched_dataset[split],
                                                      rel_map, entity_map)

        path = SAVE_DIR / f'{split}.pkl'
        with path.open('wb') as f:
            pickle.dump(dataset, f)
        print(
            f"Dataset split '{split}' saved to {path}, {len(dataset)} examples"
        )

        path = SAVE_DIR / f'{split}.span.pkl'
        with path.open('wb') as f:
            pickle.dump(matched_spans, f)
        print(f"Matched spans split '{split}' saved to {path}")
Example #3
def read_data(path: Path) -> Dict[str, List[RawExampleWikiID]]:
    bad_examples: List[Tuple[str, int, str]] = []
    data = {}
    for split in ['train', 'valid', 'test']:
        with (path / f'{split}.pkl').open('rb') as f:
            # each entry in `rels` unpacks as (obj_id, head_id, span, rel_info, surface, alias); rel_info[0][0] is the relation type
            with utils.work_in_progress(f"Loading {split} set"):
                dump: List[RawDump] = pickle.load(f)

            examples = []
            for idx, (sent, rels) in enumerate(
                    utils.progress(dump, desc='Reading data')):
                # map (rel_typ, canonical) to list of aliases, since lists aren't hashable
                rel_to_alias: Dict[Tuple[str, str], List[str]] = \
                    {(rel[0][0], obj_id): alias for obj_id, _, _, rel, _, alias in rels}

                # sort it so the order is consistent
                relations: List[RelationWikiID] = sorted([
                    RelationWikiID(WikidataID(rel_id), WikidataID(obj_id),
                                   obj_alias)
                    for (rel_id, obj_id), obj_alias in rel_to_alias.items()
                ])
                rel_to_id: Dict[Tuple[str, str], int] = {
                    (rel_id, obj_id): idx
                    for idx, (rel_id, obj_id,
                              obj_alias) in enumerate(relations)
                }
                # dedup to remove duplicate (-1, -1)
                mentions: List[EntityMention] = list(
                    set(
                        EntityMention(span, surface, rel_to_id[(rel_info[0][0],
                                                                obj_id)]) for
                        obj_id, head_id, span, rel_info, surface, _ in rels))
                try:
                    # the head ID of the @TITLE@ relation is the topic's Wikidata ID; every example must have one
                    topic_id = next(
                        head_id
                        for _, head_id, _, rel_info, surface, alias in rels
                        if rel_info[0][0] == "@TITLE@")
                except StopIteration:
                    bad_examples.append((split, idx, ' '.join(sent)[:100]))
                    continue

                converted_relations = []
                for r in relations:
                    converted_relations.append(
                        RelationWikiID(
                            TOPIC_ITSELF if r.rel_typ == "@TITLE@" else
                            r.rel_typ, r.obj, r.obj_alias))

                example = RawExampleWikiID(WikidataID(topic_id), sent,
                                           converted_relations, mentions)
                examples.append(example)
            data[split] = examples

    if len(bad_examples) > 0:
        Logging.warn(f"{len(bad_examples)} bad examples:\n"
                     f"{pprint.pformat(bad_examples)}")
    else:
        Logging.verbose("All examples are good")

    return data
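The unpacking in `read_data` implies a fixed per-relation entry layout in the raw dump. The sketch below records that layout as a NamedTuple purely for documentation; only the field order comes from the code, while the field names and exact types are assumptions.

from typing import List, NamedTuple, Tuple

class RawRelationEntry(NamedTuple):
    # Inferred from `for obj_id, head_id, span, rel_info, surface, alias in rels`.
    obj_id: str                      # Wikidata ID of the object entity
    head_id: str                     # Wikidata ID of the head entity; for "@TITLE@" rows this is the topic ID
    span: Tuple[int, int]            # token span of the mention in the sentence (assumed (start, end))
    rel_info: List[Tuple[str, ...]]  # rel_info[0][0] is the relation type string, e.g. "@TITLE@"
    surface: str                     # surface form of the mention as it appears in the text
    alias: List[str]                 # aliases of the object entity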
Example #4
def tokenize(path: Path, tokenizer: Optional[str], lang: Optional[str] = None, lowercase: bool = False,
             directory: PathType = 'data/tokenize', **tok_kwargs) -> Path:
    r"""
    Tokenize an input file or load the cached version.

    Available tokenizers and required fields:
    - `moses`: Moses from `sacremoses` package. Requires `lang`.
    - `spacy`: spaCy. Requires `lang`.
    - `spm`: SentencePiece. Requires `spm_model`: path to a trained SPM model.

    :param path: Path to the input file.
    :param tokenizer: The tokenizer to use.
    :param lang: Language of the input file. Required for some tokenizers.
    :param lowercase: If true, lowercase each tokenized word.
    :param directory: Directory to store all the tokenized files.
    :param tok_kwargs: Additional arguments to pass to the tokenizer.
    :return: Path to the tokenized input file.
    """
    # shortcut to no tokenization
    if tokenizer is None and not lowercase:
        return path

    # check for invalid arguments
    valid_tokenizers = ['moses', 'spacy', 'spm']
    if tokenizer is not None and tokenizer not in valid_tokenizers:
        raise ValueError(f"Invalid tokenizer setting \"{tokenizer}\"")
    if tokenizer in ['moses', 'spacy'] and lang is None:
        raise ValueError(f"Must specify `lang` when using \"{tokenizer}\" tokenizer")
    if tokenizer == 'spm' and tok_kwargs.get('spm_model', None) is None:
        raise ValueError("Must supply `spm_model` as additional argument when using the SentencePiece tokenizer")

    # return cached file if exists
    base_path = Path(directory) / path_lca(path, directory)
    suffix = get_tokenization_args(tokenizer, lowercase)
    cached_path = path_add_suffix(base_path, suffix)
    if cached_path.exists():
        return cached_path

    # tokenize or re-use partially processed file
    tok_path = path_add_suffix(base_path, tokenizer) if tokenizer is not None else path
    if not tok_path.exists():
        with path.open('r') as f:
            sents = [line for line in f]
        with utils.work_in_progress(f"{tokenizer}: tokenizing {path}"):
            if tokenizer == 'moses':
                assert lang is not None
                tok_sents = moses_tokenize(sents, lang)
            elif tokenizer == 'spacy':
                assert lang is not None
                tok_sents = spacy_tokenize(sents, lang)
            elif tokenizer == 'spm':
                tok_sents = spm_tokenize(sents, tok_kwargs['spm_model'])
            else:
                assert False  # make the IDE happy
    else:
        with tok_path.open('r') as f:
            tok_sents = [line.split() for line in f]  # split the cached lines back into token lists

    # lowercase
    if lowercase:
        tok_sents = [[word.lower() for word in sent] for sent in tok_sents]

    # cache the file
    cached_path.resolve().parent.mkdir(parents=True, exist_ok=True)
    with cached_path.open('w') as f:
        f.write('\n'.join(' '.join(sent) for sent in tok_sents))

    return cached_path
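A hedged usage sketch of `tokenize`; the input file and SPM model paths below are placeholders, not paths from this project:

from pathlib import Path

# Moses word tokenization with lowercasing; the result is cached under data/tokenize/.
train_tok = tokenize(Path('data/corpus/train.txt'), tokenizer='moses', lang='en', lowercase=True)

# SentencePiece tokenization; the trained model is passed through `tok_kwargs` as `spm_model`.
train_spm = tokenize(Path('data/corpus/train.txt'), tokenizer='spm', spm_model='data/spm/en.model')

# With no tokenizer and no lowercasing the call is a no-op and returns the input path unchanged.
same_path = tokenize(Path('data/corpus/train.txt'), tokenizer=None)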
Example #5
def main():
    replace_canonical = any(arg.startswith('--replace') for arg in sys.argv)
    # When not replacing canonical forms, output could instead be redirected to './data/wikifacts_orig/'
    # by overriding the global SAVE_DIR (left disabled here).
    if replace_canonical:
        print("Arguments: Will replace the canonical forms.")
    print(f"Output directory: {SAVE_DIR}")

    skip_embeddings = any(arg.startswith('--skip') for arg in sys.argv)
    if skip_embeddings:
        print("Arguments: Will skip embedding generation.")

    id_name, id_summary, id_relations, relation_types = load_wikifacts()

    entity_index, relation_index, entity_vec, relation_vec = load_openke_embeddings()

    # Check OpenKE coverage
    print(f"Entity coverage in OpenKE: "
          f"{sum(int(fid in entity_index) for fid in id_name)}/{len(id_name)}")
    print(f"Relation coverage in OpenKE: "
          f"{sum(int(fid in relation_index) for fid in relation_types)}/{len(relation_types)}")

    """ Match entity positions and generate pickled dataset """

    # Remap entities and rel-types and store the mapping
    entity_map: Dict[FreebaseID, int] = {}
    relation_map: Dict[FreebaseID, int] = {}
    mapped_entity_vecs: List[np.ndarray] = []
    mapped_relation_vecs: List[np.ndarray] = []

    # noinspection PyShadowingNames
    def get_relation_id(rel: FreebaseID) -> int:
        rel_id = relation_map.get(rel, None)
        if rel_id is None:
            rel_id = relation_map[rel] = len(relation_map)
            mapped_relation_vecs.append(extract_vector(relation_vec, relation_index[rel]))
        return rel_id

    # noinspection PyShadowingNames
    def get_entity_id(entity: FreebaseID) -> int:
        ent_id = entity_map.get(entity, None)
        if ent_id is None:
            if entity not in entity_index:  # not all covered
                return UNK_ENTITY
            ent_id = entity_map[entity] = len(entity_map)
            mapped_entity_vecs.append(extract_vector(entity_vec, entity_index[entity]))
        return ent_id

    # Create the dataset
    dataset: List[Example] = []  # data examples, in processing order
    dataset_matched_spans: List[List[MatchedSpan]] = []  # matched spans for each example, parallel to `dataset`

    # noinspection PyShadowingNames
    def find_relation(rels: List[Tuple[FreebaseID, FreebaseID]], obj: FreebaseID) -> List[FreebaseID]:
        # return all relations whose object matches `obj`; the caller treats the first match as the primary one
        matched_rels = []
        for r_rel, r_obj in rels:
            if r_obj == obj:
                matched_rels.append(r_rel)
        return matched_rels

    # noinspection PyShadowingNames
    def match_positions(tokens: List[str], name: List[str]) -> List[int]:
        # Greedily align each token in `tokens` to an unused position in `name`, e.g.
        # tokens=['Barack', 'Obama'], name=['Barack', 'Hussein', 'Obama'] gives [0, 2];
        # returns [] as soon as some token has no remaining match.
        matched = [False] * len(name)
        positions = []
        for idx, token in enumerate(tokens):
            for match_idx, match_token in enumerate(name):
                if matched[match_idx] or match_token != token:
                    continue
                positions.append(match_idx)
                matched[match_idx] = True
                break
            else:
                return []
        return positions

    tokenizer = sacremoses.MosesTokenizer(lang='en')
    position_stats = defaultdict(int)
    for freebase_id in utils.progress(id_summary, desc='Creating dataset'):
        topic_id = get_entity_id(freebase_id)
        summary = id_summary[freebase_id].strip().split(' ')
        raw_relations = id_relations[freebase_id]
        relations = defaultdict(lambda: len(relations))  # assigns a fresh consecutive index to each new (rel_id, obj_id) key on first access
        rel_obj_names: Dict[int, str] = {}
        for r_rel, r_obj in raw_relations:
            rel_obj_names[relations[(get_relation_id(r_rel), get_entity_id(r_obj))]] = id_name[r_obj]
        topic_name = summary[0][2:summary[0].index('/')] if '/' in summary[0] else "<unknown topic name>"
        rel_obj_names[relations[(TOPIC_ITSELF, topic_id)]] = topic_name  # topic_itself
        sentence, rel_ids, copy_pos, surface_indices = [], [], [], []
        matched_spans: List[MatchedSpan] = []

        def add_words(s: str):
            tokens = tokenizer.tokenize(s, escape=False)
            sentence.extend(tokens)
            rel_ids.extend([NAF] * len(tokens))
            copy_pos.extend([-1] * len(tokens))
            surface_indices.extend([-1] * len(tokens))

        for word in summary:
            if '@@' in word:
                start_pos = word.find('@@')
                end_pos = word.find('@@', start_pos + 2)

                if start_pos > 0:
                    add_words(word[:start_pos])  # leading stuff

                entity_desc = word[(start_pos + 2):end_pos].split('/')  # there could be punctuation following '@@'
                assert len(entity_desc) >= 4  # entity name could contain '/'
                trailing = word[(end_pos + 2):]  # trailing stuff
                entity_name = '/'.join(entity_desc[:-3]).split('_')
                r_obj = FreebaseID(entity_desc[-1])
                obj_id = get_entity_id(r_obj)
                if entity_desc[-3] == 'f':
                    if r_obj == freebase_id:  # topic_itself
                        rel_id = TOPIC_ITSELF
                        matched_rels = [TOPIC_ITSELF]
                    else:
                        matched_rels = [get_relation_id(r_rel) for r_rel in find_relation(raw_relations, r_obj)]
                        # the relation might not exist
                        rel_id = None if len(matched_rels) == 0 else matched_rels[0]
                else:
                    rel_id = ANCHOR
                    matched_rels = [ANCHOR]  # include anchors anyway and filter them out if we don't use them
                    # matched_rels = []  # don't include anchors because we're only using them for our model

                if rel_id is None:  # no known relation between the topic and this object (unknown entities are allowed, so only the relation can be missing here)
                    # no embedding for word, convert to normal token
                    add_words(' '.join(entity_name) + trailing)
                else:
                    position_stats['all'] += 1

                    if replace_canonical and r_obj in id_name:  # basically, everything except anchors
                        # just replace the entity_name, position will be in order
                        canonical_name = id_name[r_obj].split('_')
                        if canonical_name != entity_name:
                            position_stats['not_canonical'] += 1
                        entity_name = canonical_name

                    if entity_desc[-3] == 'f' and r_obj != freebase_id:
                        positions = match_positions(entity_name, id_name[r_obj].split('_'))
                        if len(positions) == 0:  # cannot match, resort to replacing
                            entity_name = id_name[r_obj].split('_')
                            positions = list(range(len(entity_name)))
                            position_stats['order'] += 1
                        else:
                            position_stats['match'] += 1
                            if positions == list(range(len(entity_name))):
                                position_stats['order'] += 1
                                position_stats['match_order'] += 1
                            elif positions == list(range(len(positions))):
                                position_stats['match_prefix'] += 1
                            elif positions == list(range(positions[0], positions[-1] + 1)):
                                position_stats['match_sub'] += 1
                    else:
                        # we don't replace canonical names for anchors (because we don't have them)
                        # anyway, these can't be used by our model
                        if entity_desc[-3] == 'a':
                            position_stats['anchor'] += 1
                            # add the anchor relation
                            rel_obj_names[relations[(ANCHOR, obj_id)]] = ' '.join(entity_name)
                        positions = list(range(len(entity_name)))  # these may not be in `id_name`, use as is
                        position_stats['order'] += 1

                    assert (rel_id, obj_id) in relations
                    rel_idx = relations[(rel_id, obj_id)]  # must exist
                    for rel_typ in matched_rels:
                        matched_spans.append((len(sentence), len(sentence) + len(entity_name), rel_typ, rel_idx, 0))
                    sentence.extend(entity_name)
                    rel_ids.extend([rel_idx] * len(entity_name))
                    copy_pos.extend(positions)
                    surface_indices.extend([0] * len(entity_name))

                    add_words(trailing)
            else:
                add_words(word)

        # we assume everything's canonical, so just add a pseudo canonical form
        rel_rev_map = {idx: rel for rel, idx in relations.items()}
        rel_list = []  # rebuild in index order rather than relying on dict iteration order
        for idx in range(len(rel_rev_map)):
            rel_id, obj_id = rel_rev_map[idx]
            rel_list.append((rel_id, obj_id, [rel_obj_names[idx].replace('_', ' ')]))
        example = (sentence, topic_id, rel_list, rel_ids, copy_pos, surface_indices)
        dataset.append(example)
        dataset_matched_spans.append(matched_spans)
    print(f"Position stats: {position_stats}")

    if replace_canonical:
        assert position_stats['all'] == position_stats['order']

    # Save them
    directory: Path = SAVE_DATASET_PATH('train').parent
    if not directory.exists():
        directory.mkdir(parents=True)

    # noinspection PyShadowingNames
    def split_dataset(dataset: List[T], splits: List[Tuple[int, int]]) -> List[List[T]]:
        dataset_size = len(dataset)
        dataset_splits = []
        for l, r in splits:
            start = int(dataset_size * (l / 100))
            end = int(dataset_size * (r / 100))
            dataset_splits.append(dataset[start:end])
        return dataset_splits

    splits = [(0, 80), (80, 90), (90, 100)]
    dataset_splits = split_dataset(dataset, splits)
    matched_spans_splits = split_dataset(dataset_matched_spans, splits)

    print(f"{len(dataset)} examples")
    for split, data, spans in zip(['train', 'valid', 'test'], dataset_splits, matched_spans_splits):
        path = SAVE_DATASET_PATH(split)
        with path.open('wb') as f:
            pickle.dump(data, f)
        print(f"Dataset split '{split}' saved to {path}, {len(data)} examples")

        path = SAVE_MATCHED_SPANS_PATH(split)
        with path.open('wb') as f:
            pickle.dump(spans, f)
        print(f"Matched spans split '{split}' saved to {path}")

    # save relation type names for use during generation
    rel_names: Dict[int, str] = {NAF: 'Not-A-Fact', ANCHOR: 'ANCHOR', TOPIC_ITSELF: 'TITLE'}
    for r_rel, rel_id in relation_map.items():
        rel_names[rel_id] = r_rel
    with (SAVE_DIR / 'rel_names.pkl').open('wb') as f:
        pickle.dump(rel_names, f)
    print("Relation names saved.")

    if not skip_embeddings:
        with utils.work_in_progress("Saving entity vecs"):
            stacked_entity_vecs = torch.from_numpy(np.stack(mapped_entity_vecs))
            torch.save(stacked_entity_vecs, SAVE_ENTITY_VEC_PATH)
        print(f"Entity vecs saved to {SAVE_ENTITY_VEC_PATH}, {len(stacked_entity_vecs)} vectors in total")

        with utils.work_in_progress("Saving relation vecs"):
            stacked_relation_vecs = torch.from_numpy(np.stack(mapped_relation_vecs))
            torch.save(stacked_relation_vecs, SAVE_RELATION_VEC_PATH)
        print(f"Relation vecs saved to {SAVE_RELATION_VEC_PATH}, {len(stacked_relation_vecs)} vectors in total")
    else:
        print("Embedding updates skipped.")

    print("Processing done.")