def load_openke_embeddings():
    """Load OpenKE pretrained embeddings."""
    entity_index: Dict[FreebaseID, int] = {}
    relation_index: Dict[FreebaseID, int] = {}

    # Entity to OpenKE ID mapping
    with utils.work_in_progress("Reading entity2id"), ENTITY_MAPPING_PATH.open() as f:
        entity_count = int(f.readline())
        for line in f:
            parts = line.split()
            freebase_id = FreebaseID(parts[0])
            entity_index[freebase_id] = int(parts[1])
    assert len(entity_index) == entity_count

    # Relation to OpenKE ID mapping
    with utils.work_in_progress("Reading relation2id"), RELATION_MAPPING_PATH.open() as f:
        relation_count = int(f.readline())
        for line in f:
            parts = line.split()
            freebase_id = FreebaseID(parts[0])
            relation_index[freebase_id] = int(parts[1])
    assert len(relation_index) == relation_count

    # Load binary vectors
    entity_vec = np.memmap(ENTITY_VEC_PATH, dtype=np.float32, mode='r')
    relation_vec = np.memmap(RELATION_VEC_PATH, dtype=np.float32, mode='r')

    return entity_index, relation_index, entity_vec, relation_vec

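# The two memmaps returned above are flat float32 buffers in which the vector for
# OpenKE index `i` occupies the contiguous block [i * dim, (i + 1) * dim). The helper
# below is only a minimal sketch of how a single vector could be sliced out, assuming
# the embedding dimension can be inferred from the buffer size; the `extract_vector`
# helper used elsewhere in this repository may be implemented differently.
def _extract_vector_sketch(vec: np.ndarray, index: int, num_vectors: int) -> np.ndarray:
    """Return a copy of the `index`-th embedding from a flat buffer of `num_vectors` vectors."""
    assert vec.size % num_vectors == 0, "buffer size must be a multiple of the vector count"
    dim = vec.size // num_vectors  # hypothetical: infer the dimension from the total size
    return np.array(vec[index * dim:(index + 1) * dim])
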
def main():
    Logging.verbosity_level = Logging.VERBOSE
    Logging.warn("This program requires lots of memory (preferably >= 30GB).")

    if not SAVE_DIR.exists():
        SAVE_DIR.mkdir(parents=True)

    # Read the Wikidata IDs for each article, and filter the relations
    topic_ids: Set[WikidataID] = set()
    split_title_id: Dict[str, List[Tuple[str, WikidataID]]] = {}
    for split in ['train', 'valid', 'test']:
        with utils.work_in_progress(f"Loading {split} set titles"), \
                open(TOPIC_JSON_PATH(split=split)) as f:
            j = json.load(f)
            split_title_id[split] = [(article['title'], WikidataID(article['id']))
                                     for article in j]
            topic_ids.update([wid for _, wid in split_title_id[split]])
            del j

    with utils.work_in_progress("Loading Wikidata ID mapping"):
        id2rel = load_id2str(WIKIDATA_DUMP_DIR / 'properties.txt')

    # Match the relations
    matched_dataset = read_data(ALIGNED_DATA_DIR)

    # Gather entities & relation vectors
    found_entities = set()
    found_rels = set()
    for split in matched_dataset:
        for example in matched_dataset[split]:
            found_entities.add(example.topic_id)
            for rel in example.relations:
                found_entities.add(rel.obj_id)
                found_rels.add(rel.rel_typ)
    found_entities -= {UNK_ENTITY}
    found_rels -= {NAF, ANCHOR, TOPIC_ITSELF}

    with utils.work_in_progress("Building rel vecs"):
        rel_map = load_relations(found_rels)
        rel_map.update({NAF: -1, ANCHOR: -2, TOPIC_ITSELF: -3})
        unk_rels = found_rels.difference(rel_map)
        # NOTE: `unk_rels` is a set whose iteration order is undetermined,
        # so we sort it to keep the mapping consistent between runs.
        for idx, rel in enumerate(sorted(unk_rels)):
            rel_map[rel] = -idx - 4  # starting from -4, going towards -inf

    with utils.work_in_progress("Building entity vecs"):
        entity_map = load_entities(found_entities)
        entity_map.update({UNK_ENTITY: -1})

    print(f"Topic ID coverage: {len(topic_ids.intersection(entity_map))}/{len(topic_ids)}")

    # Save relation type names for use during generation
    id_to_rel_name = dict(id2rel)
    id_to_rel_name.update({NAF: 'Not-A-Fact', ANCHOR: 'ANCHOR', TOPIC_ITSELF: 'TITLE'})
    rel_names: Dict[int, str] = {}
    for r_rel, rel_id in rel_map.items():
        rel_names[rel_id] = id_to_rel_name[r_rel]
    with (SAVE_DIR / 'rel_names.pkl').open('wb') as f:
        pickle.dump(rel_names, f)
    print(f"Relation names saved to {SAVE_DIR / 'rel_names.pkl'}")

    # Convert into numbers to create the final dataset
    for split in matched_dataset:
        with utils.work_in_progress(f"Converting {split} set"):
            dataset, matched_spans = numericalize_rel(matched_dataset[split], rel_map, entity_map)

        path = SAVE_DIR / f'{split}.pkl'
        with path.open('wb') as f:
            pickle.dump(dataset, f)
        print(f"Dataset split '{split}' saved to {path}, {len(dataset)} examples")

        path = SAVE_DIR / f'{split}.span.pkl'
        with path.open('wb') as f:
            pickle.dump(matched_spans, f)
        print(f"Matched spans split '{split}' saved to {path}")

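# Illustrative note on the relation-ID convention built in `main()` above: special
# relation types get fixed negative IDs (NAF = -1, ANCHOR = -2, TOPIC_ITSELF = -3),
# relations without pretrained vectors continue downwards from -4 in sorted order,
# and the remaining IDs are the ones assigned by `load_relations`, which presumably
# index rows of the pretrained relation-vector matrix. Example shape (relation names
# are hypothetical):
#
#     rel_map = {NAF: -1, ANCHOR: -2, TOPIC_ITSELF: -3,
#                'P_unseen_a': -4, 'P_unseen_b': -5, ...}
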
def read_data(path: Path) -> Dict[str, List[RawExampleWikiID]]:
    bad_examples: List[Tuple[str, int, str]] = []
    data = {}
    for split in ['train', 'valid', 'test']:
        with (path / f'{split}.pkl').open('rb') as f:
            # relation tuple: (span, rel_type_desc, name, canonical_name)
            with utils.work_in_progress(f"Loading {split} set"):
                dump: List[RawDump] = pickle.load(f)

        examples = []
        for idx, (sent, rels) in enumerate(utils.progress(dump, desc='Reading data')):
            # Map (rel_typ, canonical) to the list of aliases, since lists aren't hashable.
            rel_to_alias: Dict[Tuple[str, str], List[str]] = \
                {(rel[0][0], obj_id): alias for obj_id, _, _, rel, _, alias in rels}
            # Sort so the order is consistent between runs.
            relations: List[RelationWikiID] = sorted([
                RelationWikiID(WikidataID(rel_id), WikidataID(obj_id), obj_alias)
                for (rel_id, obj_id), obj_alias in rel_to_alias.items()
            ])
            rel_to_id: Dict[Tuple[str, str], int] = {
                (rel_id, obj_id): idx
                for idx, (rel_id, obj_id, obj_alias) in enumerate(relations)
            }
            # Deduplicate to remove duplicate (-1, -1) mentions.
            mentions: List[EntityMention] = list(set(
                EntityMention(span, surface, rel_to_id[(rel_info[0][0], obj_id)])
                for obj_id, head_id, span, rel_info, surface, _ in rels))
            try:
                # Must exist: the head ID of the @TITLE@ relation is the topic WikidataID.
                topic_id = next(head_id for _, head_id, _, rel_info, surface, alias in rels
                                if rel_info[0][0] == "@TITLE@")
            except StopIteration:
                bad_examples.append((split, idx, ' '.join(sent)[:100]))
                continue

            converted_relations = []
            for r in relations:
                converted_relations.append(RelationWikiID(
                    TOPIC_ITSELF if r.rel_typ == "@TITLE@" else r.rel_typ,
                    r.obj, r.obj_alias))

            example = RawExampleWikiID(WikidataID(topic_id), sent,
                                       converted_relations, mentions)
            examples.append(example)
        data[split] = examples

    if len(bad_examples) > 0:
        Logging.warn(f"{len(bad_examples)} bad examples:\n"
                     f"{pprint.pformat(bad_examples)}")
    else:
        Logging.verbose("All examples are good")
    return data

def tokenize(path: Path, tokenizer: Optional[str], lang: Optional[str] = None,
             lowercase: bool = False, directory: PathType = 'data/tokenize',
             **tok_kwargs) -> Path:
    r"""Tokenize an input file or load the cached version.

    Available tokenizers and required arguments:

    - `moses`: Moses from the `sacremoses` package. Requires `lang`.
    - `spacy`: spaCy. Requires `lang`.
    - `spm`: SentencePiece. Requires `spm_model`: path to a trained SPM model.

    :param path: Path to the input file.
    :param tokenizer: The tokenizer to use.
    :param lang: Language of the input file. Required for some tokenizers.
    :param lowercase: If true, lowercase each tokenized word.
    :param directory: Directory to store all the tokenized files.
    :param tok_kwargs: Additional arguments to pass to the tokenizer.
    :return: Path to the tokenized input file.
    """
    # Shortcut: no tokenization required.
    if tokenizer is None and not lowercase:
        return path

    # Check for invalid arguments.
    valid_tokenizers = ['moses', 'spacy', 'spm']
    if tokenizer is not None and tokenizer not in valid_tokenizers:
        raise ValueError(f"Invalid tokenizer setting \"{tokenizer}\"")
    if tokenizer in ['moses', 'spacy'] and lang is None:
        raise ValueError(f"Must specify `lang` when using \"{tokenizer}\" tokenizer")
    if tokenizer == 'spm' and tok_kwargs.get('spm_model', None) is None:
        raise ValueError("Must supply `spm_model` as additional argument "
                         "when using the SentencePiece tokenizer")

    # Return the cached file if it exists.
    base_path = Path(directory) / path_lca(path, directory)
    suffix = get_tokenization_args(tokenizer, lowercase)
    cached_path = path_add_suffix(base_path, suffix)
    if cached_path.exists():
        return cached_path

    # Tokenize, or reuse a partially processed (tokenized but not lowercased) file.
    tok_path = path_add_suffix(base_path, tokenizer) if tokenizer is not None else path
    if not tok_path.exists():
        with path.open('r') as f:
            sents = [line for line in f]
        with utils.work_in_progress(f"{tokenizer}: tokenizing {path}"):
            if tokenizer == 'moses':
                assert lang is not None
                tok_sents = moses_tokenize(sents, lang)
            elif tokenizer == 'spacy':
                assert lang is not None
                tok_sents = spacy_tokenize(sents, lang)
            elif tokenizer == 'spm':
                tok_sents = spm_tokenize(sents, tok_kwargs['spm_model'])
            else:
                assert False  # make the IDE happy
    else:
        with tok_path.open('r') as f:
            # Cached tokenized files store space-separated tokens, one sentence per line.
            tok_sents = [line.split() for line in f]

    # Lowercase if requested.
    if lowercase:
        tok_sents = [[word.lower() for word in sent] for sent in tok_sents]

    # Cache the result.
    cached_path.resolve().parent.mkdir(parents=True, exist_ok=True)
    with cached_path.open('w') as f:
        f.write('\n'.join(' '.join(sent) for sent in tok_sents))
    return cached_path

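# Example usage (a sketch; the input path is hypothetical): tokenize an English
# corpus file with Moses and lowercase it, caching the result under data/tokenize/.
#
#     corpus = tokenize(Path('data/raw/train.txt'), tokenizer='moses',
#                       lang='en', lowercase=True)
#     # Calling it again with the same arguments returns the cached path immediately.
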
def main():
    replace_canonical = any(arg.startswith('--replace') for arg in sys.argv)
    if not replace_canonical:
        # global SAVE_DIR
        # SAVE_DIR = Path('./data/wikifacts_orig/')
        pass
    else:
        print("Arguments: Will replace the canonical forms.")
    print(f"Output directory: {SAVE_DIR}")

    skip_embeddings = any(arg.startswith('--skip') for arg in sys.argv)
    if skip_embeddings:
        print("Arguments: Will skip embedding generation.")

    id_name, id_summary, id_relations, relation_types = load_wikifacts()
    entity_index, relation_index, entity_vec, relation_vec = load_openke_embeddings()

    # Check OpenKE coverage
    print(f"Entity coverage in OpenKE: "
          f"{sum(int(fid in entity_index) for fid in id_name)}/{len(id_name)}")
    print(f"Relation coverage in OpenKE: "
          f"{sum(int(fid in relation_index) for fid in relation_types)}/{len(relation_types)}")

    # Match entity positions and generate the pickled dataset.

    # Remap entities and relation types, and store the mapping.
    entity_map: Dict[FreebaseID, int] = {}
    relation_map: Dict[FreebaseID, int] = {}
    mapped_entity_vecs: List[np.ndarray] = []
    mapped_relation_vecs: List[np.ndarray] = []

    # noinspection PyShadowingNames
    def get_relation_id(rel: FreebaseID) -> int:
        rel_id = relation_map.get(rel, None)
        if rel_id is None:
            rel_id = relation_map[rel] = len(relation_map)
            mapped_relation_vecs.append(extract_vector(relation_vec, relation_index[rel]))
        return rel_id

    # noinspection PyShadowingNames
    def get_entity_id(entity: FreebaseID) -> int:
        ent_id = entity_map.get(entity, None)
        if ent_id is None:
            if entity not in entity_index:  # not all entities are covered by OpenKE
                return UNK_ENTITY
            ent_id = entity_map[entity] = len(entity_map)
            mapped_entity_vecs.append(extract_vector(entity_vec, entity_index[entity]))
        return ent_id

    # Create the dataset
    dataset: List[Example] = []  # one example per topic, in iteration order
    dataset_matched_spans: List[List[MatchedSpan]] = []  # matched spans per example, same order

    # noinspection PyShadowingNames
    def find_relation(rels: List[Tuple[FreebaseID, FreebaseID]], obj: FreebaseID) -> List[FreebaseID]:
        # Return all relation types matching the object; callers that need one take the first.
        matched_rels = []
        for r_rel, r_obj in rels:
            if r_obj == obj:
                matched_rels.append(r_rel)
        return matched_rels

    # noinspection PyShadowingNames
    def match_positions(tokens: List[str], name: List[str]) -> List[int]:
        matched = [False] * len(name)
        positions = []
        for idx, token in enumerate(tokens):
            for match_idx, match_token in enumerate(name):
                if matched[match_idx] or match_token != token:
                    continue
                positions.append(match_idx)
                matched[match_idx] = True
                break
            else:
                return []
        return positions

    tokenizer = sacremoses.MosesTokenizer(lang='en')
    position_stats = defaultdict(int)
    for freebase_id in utils.progress(id_summary, desc='Creating dataset'):
        topic_id = get_entity_id(freebase_id)
        summary = id_summary[freebase_id].strip().split(' ')
        raw_relations = id_relations[freebase_id]
        relations = defaultdict(lambda: len(relations))
        rel_obj_names: Dict[int, str] = {}
        for r_rel, r_obj in raw_relations:
            rel_obj_names[relations[(get_relation_id(r_rel), get_entity_id(r_obj))]] = id_name[r_obj]
        topic_name = (summary[0][2:summary[0].index('/')]
                      if '/' in summary[0] else "<unknown topic name>")
        rel_obj_names[relations[(TOPIC_ITSELF, topic_id)]] = topic_name  # the topic itself

        sentence, rel_ids, copy_pos, surface_indices = [], [], [], []
        matched_spans: List[MatchedSpan] = []

        def add_words(s: str):
            tokens = tokenizer.tokenize(s, escape=False)
            sentence.extend(tokens)
            rel_ids.extend([NAF] * len(tokens))
            copy_pos.extend([-1] * len(tokens))
            surface_indices.extend([-1] * len(tokens))

        for word in summary:
            if '@@' in word:
                start_pos = word.find('@@')
                end_pos = word.find('@@', start_pos + 2)
                if start_pos > 0:
                    add_words(word[:start_pos])  # leading stuff
                entity_desc = word[(start_pos + 2):end_pos].split('/')
                # there could be punctuation following the closing '@@'
                assert len(entity_desc) >= 4  # entity name could contain '/'
                trailing = word[(end_pos + 2):]  # trailing stuff
                entity_name = '/'.join(entity_desc[:-3]).split('_')
                r_obj = FreebaseID(entity_desc[-1])
                obj_id = get_entity_id(r_obj)

                if entity_desc[-3] == 'f':
                    if r_obj == freebase_id:  # the topic itself
                        rel_id = TOPIC_ITSELF
                        matched_rels = [TOPIC_ITSELF]
                    else:
                        matched_rels = [get_relation_id(r_rel)
                                        for r_rel in find_relation(raw_relations, r_obj)]
                        # the relation might not exist
                        rel_id = None if len(matched_rels) == 0 else matched_rels[0]
                else:
                    rel_id = ANCHOR
                    matched_rels = [ANCHOR]  # include anchors anyway and filter them out if we don't use them
                    # matched_rels = []  # don't include anchors because we're only using them for our model

                # previously `if obj_id is None or rel_id is None`; unknown entities are now allowed
                if rel_id is None:
                    # no embedding for this word, convert it to normal tokens
                    add_words(' '.join(entity_name) + trailing)
                else:
                    position_stats['all'] += 1
                    if replace_canonical and r_obj in id_name:
                        # basically everything except anchors: just replace the
                        # entity name, so positions will be in order
                        canonical_name = id_name[r_obj].split('_')
                        if canonical_name != entity_name:
                            position_stats['not_canonical'] += 1
                            entity_name = canonical_name

                    if entity_desc[-3] == 'f' and r_obj != freebase_id:
                        positions = match_positions(entity_name, id_name[r_obj].split('_'))
                        if len(positions) == 0:
                            # cannot match, resort to replacing with the canonical name
                            entity_name = id_name[r_obj].split('_')
                            positions = list(range(len(entity_name)))
                            position_stats['order'] += 1
                        else:
                            position_stats['match'] += 1
                            if positions == list(range(len(entity_name))):
                                position_stats['order'] += 1
                                position_stats['match_order'] += 1
                            elif positions == list(range(len(positions))):
                                position_stats['match_prefix'] += 1
                            elif positions == list(range(positions[0], positions[-1] + 1)):
                                position_stats['match_sub'] += 1
                    else:
                        # we don't replace canonical names for anchors (because we don't have them);
                        # anyway, these can't be used by our model
                        if entity_desc[-3] == 'a':
                            position_stats['anchor'] += 1
                            # add the anchor relation
                            rel_obj_names[relations[(ANCHOR, obj_id)]] = ' '.join(entity_name)
                        positions = list(range(len(entity_name)))  # these may not be in `id_name`, use as is
                        position_stats['order'] += 1

                    assert (rel_id, obj_id) in relations
                    rel_idx = relations[(rel_id, obj_id)]  # must exist
                    for rel_typ in matched_rels:
                        matched_spans.append((len(sentence), len(sentence) + len(entity_name),
                                              rel_typ, rel_idx, 0))
                    sentence.extend(entity_name)
                    rel_ids.extend([rel_idx] * len(entity_name))
                    copy_pos.extend(positions)
                    surface_indices.extend([0] * len(entity_name))
                    add_words(trailing)
            else:
                add_words(word)

        # we assume everything's canonical, so just add a pseudo canonical form
        rel_rev_map = {idx: rel for rel, idx in relations.items()}
        rel_list = []  # just in case `list(relations)` doesn't do what we expect
        for idx in range(len(rel_rev_map)):
            rel_id, obj_id = rel_rev_map[idx]
            rel_list.append((rel_id, obj_id, [rel_obj_names[idx].replace('_', ' ')]))

        example = (sentence, topic_id, rel_list, rel_ids, copy_pos, surface_indices)
        dataset.append(example)
        dataset_matched_spans.append(matched_spans)

    print(f"Position stats: {position_stats}")
    if replace_canonical:
        assert position_stats['all'] == position_stats['order']

    # Save them
    directory: Path = SAVE_DATASET_PATH('train').parent
    if not directory.exists():
        directory.mkdir(parents=True)

    # noinspection PyShadowingNames
    def split_dataset(dataset: List[T], splits: List[Tuple[int, int]]) -> List[List[T]]:
        dataset_size = len(dataset)
        dataset_splits = []
        for l, r in splits:
            start = int(dataset_size * (l / 100))
            end = int(dataset_size * (r / 100))
            dataset_splits.append(dataset[start:end])
        return dataset_splits

    splits = [(0, 80), (80, 90), (90, 100)]
    dataset_splits = split_dataset(dataset, splits)
    matched_spans_splits = split_dataset(dataset_matched_spans, splits)
    print(f"{len(dataset)} examples")

    for split, data, spans in zip(['train', 'valid', 'test'],
                                  dataset_splits, matched_spans_splits):
        path = SAVE_DATASET_PATH(split)
        with path.open('wb') as f:
            pickle.dump(data, f)
        print(f"Dataset split '{split}' saved to {path}, {len(data)} examples")

        path = SAVE_MATCHED_SPANS_PATH(split)
        with path.open('wb') as f:
            pickle.dump(spans, f)
        print(f"Matched spans split '{split}' saved to {path}")

    # Save relation type names for use during generation
    rel_names: Dict[int, str] = {NAF: 'Not-A-Fact', ANCHOR: 'ANCHOR', TOPIC_ITSELF: 'TITLE'}
    for r_rel, rel_id in relation_map.items():
        rel_names[rel_id] = r_rel
    with (SAVE_DIR / 'rel_names.pkl').open('wb') as f:
        pickle.dump(rel_names, f)
    print("Relation names saved.")

    if not skip_embeddings:
        with utils.work_in_progress("Saving entity vecs"):
            stacked_entity_vecs = torch.from_numpy(np.stack(mapped_entity_vecs))
            torch.save(stacked_entity_vecs, SAVE_ENTITY_VEC_PATH)
        print(f"Entity vecs saved to {SAVE_ENTITY_VEC_PATH}, "
              f"{len(stacked_entity_vecs)} vectors in total")

        with utils.work_in_progress("Saving relation vecs"):
            stacked_relation_vecs = torch.from_numpy(np.stack(mapped_relation_vecs))
            torch.save(stacked_relation_vecs, SAVE_RELATION_VEC_PATH)
        print(f"Relation vecs saved to {SAVE_RELATION_VEC_PATH}, "
              f"{len(stacked_relation_vecs)} vectors in total")
    else:
        print("Embedding updates skipped.")

    print("Processing done.")
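
# Illustrative note (not executed anywhere): with the percentage splits used in
# `main()` above, e.g. 1000 examples and splits [(0, 80), (80, 90), (90, 100)],
# `split_dataset` returns slices of 800, 100, and 100 examples. Because each
# boundary is computed as int(dataset_size * percent / 100), consecutive splits
# share exact boundaries, so the slices never overlap and jointly cover the
# whole dataset.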