def create_batches(self, batch_size: int, bptt_size: int):
    r"""A general routine to create batches of the specified batch size and BPTT length.

    :param batch_size: The number of examples in one batch.
    :param bptt_size: The length for truncated backpropagation, i.e. the maximum
        length of sentences in one batch.
    """
    self.batches = {}
    self.ntokens = {}
    for split, raw_dataset in self.data.items():
        ntokens = 0
        # sort the data by document length
        parts = sorted(raw_dataset, key=len)
        num_batches = utils.ceil_div(len(parts), batch_size)
        all_batches = []
        for batch_idx in utils.progress(num_batches, desc="Creating batches",
                                        ascii=True, ncols=80):
            part = parts[(batch_idx * batch_size):((batch_idx + 1) * batch_size)]
            init_batch, batches = self.create_one_batch(part, bptt_size)
            ntokens += sum(batch.ntokens for batch in batches)
            all_batches.append((init_batch, batches))
        self.batches[split] = all_batches
        self.ntokens[split] = ntokens

    unk_probs = self.unk_probs
    if unk_probs is not None:
        total_w2i = self.total_w2i
        for split, dataset in self.batches.items():
            dataset = utils.progress(dataset, ncols=80,
                                     desc=f"Adding unk vocab for {split} set",
                                     ascii=True)
            for _, batches in dataset:
                for batch in batches:
                    batch.add_unk_probs(unk_probs, total_w2i)
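
# A minimal sketch of the `utils.ceil_div` helper used above, shown here only to
# make the batch-count computation concrete (the real implementation lives in
# `utils`; this is an assumption about its behavior, not the repo's code):
def ceil_div(numerator: int, denominator: int) -> int:
    """Integer division that rounds up, e.g. ceil_div(10, 4) == 3."""
    return -(-numerator // denominator)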
def average_copies(example, start_symbol, max_length=200, n_tries=10,
                   progress=False, show_samples=False, **kwargs):
    complete_copies = utils.SimpleAverage()
    incomplete_copies = utils.SimpleAverage()
    if show_samples:
        sample_kwargs = dict(print_info=True, color_outputs=True,
                             color_incomplete=False, **kwargs)
    else:
        sample_kwargs = dict(print_info=False, color_outputs=False, **kwargs)
    for _ in utils.progress(n_tries, verbose=progress):
        output: SampledOutput = model.sampling_decode(
            dataset.vocab, example, begin_symbol=start_symbol, end_symbol=end_symbol,
            max_length=max_length, **sample_kwargs)
        complete_copies.add(output.complete_copies)
        incomplete_copies.add(output.incomplete_copies)
        if show_samples:
            print(' '.join(output.sentence))
    if isinstance(example, LRLMExample):
        n_gold_rels = sum(int(span.end < max_length) for span in example.spans)
        print(f"Complete / Gold entities: {complete_copies.value()} / {n_gold_rels}")
    else:
        print(f"Complete / Incomplete entities: "
              f"{complete_copies.value()} / {incomplete_copies.value()}")
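
# A minimal sketch of the `utils.SimpleAverage` accumulator used above, inferred
# from its `add`/`value` call sites (an assumption, not the repo's implementation):
class SimpleAverage:
    """Running mean over the values passed to `add`."""

    def __init__(self) -> None:
        self._sum = 0.0
        self._count = 0

    def add(self, value: float) -> None:
        self._sum += value
        self._count += 1

    def value(self) -> float:
        return self._sum / self._count if self._count > 0 else 0.0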
def read_data(path: Path) -> Dict[str, List[RawExampleWikiID]]:
    bad_examples: List[Tuple[str, int, str]] = []
    data = {}
    for split in ['train', 'valid', 'test']:
        with (path / f'{split}.pkl').open('rb') as f:
            # relation tuple: (span, rel_type_desc, name, canonical_name)
            with utils.work_in_progress(f"Loading {split} set"):
                dump: List[RawDump] = pickle.load(f)
        examples = []
        for idx, (sent, rels) in enumerate(utils.progress(dump, desc='Reading data')):
            # map (rel_typ, obj_id) to the list of aliases, since lists aren't hashable
            rel_to_alias: Dict[Tuple[str, str], List[str]] = \
                {(rel[0][0], obj_id): alias for obj_id, _, _, rel, _, alias in rels}
            # sort so the order is consistent across runs
            relations: List[RelationWikiID] = sorted(
                RelationWikiID(WikidataID(rel_id), WikidataID(obj_id), obj_alias)
                for (rel_id, obj_id), obj_alias in rel_to_alias.items())
            rel_to_id: Dict[Tuple[str, str], int] = {
                (rel_id, obj_id): rel_idx
                for rel_idx, (rel_id, obj_id, _) in enumerate(relations)}
            # deduplicate to remove duplicate (-1, -1) mentions
            mentions: List[EntityMention] = list(set(
                EntityMention(span, surface, rel_to_id[(rel_info[0][0], obj_id)])
                for obj_id, head_id, span, rel_info, surface, _ in rels))
            try:
                # must exist: the head ID of the "@TITLE@" relation is the topic's WikidataID
                topic_id = next(head_id for _, head_id, _, rel_info, surface, alias in rels
                                if rel_info[0][0] == "@TITLE@")
            except StopIteration:
                bad_examples.append((split, idx, ' '.join(sent)[:100]))
                continue
            converted_relations = [
                RelationWikiID(TOPIC_ITSELF if r.rel_typ == "@TITLE@" else r.rel_typ,
                               r.obj, r.obj_alias)
                for r in relations]
            example = RawExampleWikiID(WikidataID(topic_id), sent,
                                       converted_relations, mentions)
            examples.append(example)
        data[split] = examples
    if len(bad_examples) > 0:
        Logging.warn(f"{len(bad_examples)} bad examples:\n"
                     f"{pprint.pformat(bad_examples)}")
    else:
        Logging.verbose("All examples are good")
    return data
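
# Hedged sketches of the record types `read_data` builds, with field names
# inferred from the attribute accesses above (`rel_typ`, `obj`, `obj_alias`) and
# from the positional construction of `EntityMention`; the authoritative
# definitions live elsewhere in the repo:
from typing import List, NamedTuple, Tuple

class RelationWikiIDSketch(NamedTuple):
    rel_typ: str          # Wikidata relation ID, or TOPIC_ITSELF for "@TITLE@"
    obj: str              # Wikidata ID of the object entity
    obj_alias: List[str]  # known aliases of the object

class EntityMentionSketch(NamedTuple):
    span: Tuple[int, int]  # token-level (start, end) of the mention
    surface: str           # surface form as it appears in the sentence
    rel_idx: int           # index into the example's relation list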
def _extra_init(self, loaded_batches: bool):
    self.rel_vocab = Vocab.from_dict(self._path / 'rel_names.pkl', mode='i2w')
    self.vocab: Dict[str, Vocab] = {"word": self.word_vocab, "rel": self.rel_vocab}
    self.max_unkrel = max((-rel_typ - 3 for rel_typ in self.rel_vocab.i2w
                           if rel_typ < -3), default=0)

    if self._use_fasttext:
        def _alias_path(name):
            path = Path(self._fasttext_model_path)
            return path.parent / (path.name + f'.{name}')

        # gather all entity aliases and compute fastText embeddings
        alias_dict_path = _alias_path('alias_dict.pkl')
        if alias_dict_path.exists():
            alias_dict: Dict[str, int] = loadpkl(alias_dict_path)
            loaded = True
        else:
            alias_dict = defaultdict(lambda: len(alias_dict))
            loaded = False
        if not loaded_batches:
            # replace alias strings with their indices in `alias_dict`
            for dataset in self.data.values():
                for example in dataset:
                    for idx, rel in enumerate(example.relations):  # type: ignore
                        example.relations[idx] = rel._replace(  # type: ignore
                            obj_alias=[alias_dict[s] for s in rel.obj_alias])
        if not alias_dict_path.exists():
            alias_dict = dict(alias_dict)
            savepkl(alias_dict, alias_dict_path)

        alias_vectors_path = _alias_path('alias_vectors.pt')
        if not alias_vectors_path.exists() or not loaded:
            import fastText
            ft_model = fastText.load_model(self._fasttext_model_path)
            alias_vectors = []
            alias_list = utils.reverse_map(alias_dict)
            for alias in utils.progress(alias_list, desc="Building fastText vectors",
                                        ascii=True, ncols=80):
                # sum the word vectors over the tokens of each alias
                vectors = [ft_model.get_word_vector(w) for w in alias.split()]
                alias_vectors.append(np.sum(vectors, axis=0).tolist())
            alias_vectors = torch.tensor(alias_vectors)
            torch.save(alias_vectors, alias_vectors_path)

    if not loaded_batches and (self._exclude_entity_disamb or self._exclude_alias_disamb):
        # no need to do this if batches are loaded
        if self._exclude_entity_disamb:
            # gather training set statistics
            self.entity_count_per_type = self.gather_entity_stats(self.data['train'])
        for dataset in self.data.values():
            for idx in range(len(dataset)):
                dataset[idx] = self.remove_ambiguity(
                    dataset[idx], self._exclude_entity_disamb, self._exclude_alias_disamb)
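
# A minimal sketch of `utils.reverse_map`, which the loop above assumes turns a
# {key: index} mapping into an index-ordered list of keys (an assumption based
# on how `alias_list` is consumed, not the repo's implementation):
from typing import Dict, List

def reverse_map(mapping: Dict[str, int]) -> List[str]:
    result: List[str] = [''] * len(mapping)
    for key, idx in mapping.items():
        result[idx] = key  # indices are assumed to form a dense 0..len-1 range
    return result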
def load_wikifacts():
    """Load the WikiFacts dataset."""
    id_name: Dict[FreebaseID, str] = {}
    id_summary: Dict[FreebaseID, str] = {}
    id_relations: Dict[FreebaseID, List[Relation]] = {}
    relation_types: Set[FreebaseID] = set()

    def add_or_check(fid: FreebaseID, name: str):
        name = name.lower()
        if fid in id_name:
            assert name == id_name[fid]
        else:
            id_name[fid] = name

    tar = tarfile.open(DATASET_PATH)
    id_fb_raw: Dict[FreebaseID, str] = {}
    id_en_raw: Dict[FreebaseID, str] = {}

    # Annotated Wikipedia summaries (.sm)
    for summary_file in utils.progress(tar.getmembers(), desc='Reading summary (.sm)'):
        if not summary_file.name.endswith('.sm'):
            continue
        freebase_id = extract_id(summary_file)
        with tar.extractfile(summary_file) as f:
            id_summary[freebase_id] = f.read().decode('utf-8')

    # Cache Freebase topic files (.fb, .en)
    for file in utils.progress(tar.getmembers(), desc='Caching topic files'):
        freebase_id = extract_id(file)
        if file.name.endswith('.fb'):
            with tar.extractfile(file) as f:
                id_fb_raw[freebase_id] = f.read().decode('utf-8')
        elif file.name.endswith('.en'):
            with tar.extractfile(file) as f:
                id_en_raw[freebase_id] = f.read().decode('utf-8')

    # Extract relations from the Freebase topic files (.fb, .en)
    for freebase_id in utils.progress(id_fb_raw, desc='Extracting relations'):
        relations = []
        f_fb = id_fb_raw[freebase_id].split('\n')
        f_en = id_en_raw[freebase_id].split('\n')
        for rel_line, name_line in zip(f_fb, f_en):
            parts = rel_line.split()
            if len(parts) == 0:  # empty line
                continue
            # split on plain spaces only; do not split on non-breaking or other Unicode spaces
            rel_name = name_line.strip().split(' ')
            assert len(parts) == len(rel_name)
            if len(parts) == 5:  # composite value type (CVT)
                if parts[0].startswith('['):  # subject is a CVT
                    continue
                parts = [parts[0], parts[3], parts[4].rstrip(']')]
                if parts[0] == parts[-1]:  # don't include a relation of an entity with itself
                    continue
                # remove the final ']', but do not use `.rstrip` because the name could contain ']'
                rel_name = [rel_name[0], rel_name[3], rel_name[4][:-1]]
            elif len(parts) == 3:  # simple relation
                pass
            else:
                raise ValueError  # malformed data
            r_sub, r_rel, r_obj = [remove_prefix(r) for r in parts]
            if r_sub != freebase_id:  # only keep relations whose subject is the topic
                continue
            add_or_check(r_sub, rel_name[0])
            add_or_check(r_obj, rel_name[2])
            relations.append((r_rel, r_obj))
            relation_types.add(r_rel)
        id_relations[freebase_id] = relations

    # # Freebase-to-Wikidata mappings
    # freebase_wikidata_map = {}
    # with utils.progress(open(FREEBASE_MAPPING_PATH), desc='Reading fb2w') as f:
    #     for line in f:
    #         if line.startswith('#') or line.strip() == '':
    #             continue
    #         parts = line.split()
    #         freebase_id = parts[0].split('/')[-1][:-1]
    #         wikidata_id = parts[2].split('/')[-1][:-1]
    #         freebase_wikidata_map[freebase_id] = wikidata_id
    #
    # # Check coverage
    # print(f"id_name coverage in fb2w: "
    #       f"{sum(int(fid in freebase_wikidata_map) for fid in id_name)}/{len(id_name)}")
    # print(f"id_summary coverage in fb2w: "
    #       f"{sum(int(fid in freebase_wikidata_map) for fid in id_summary)}/{len(id_summary)}")

    return id_name, id_summary, id_relations, relation_types
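
# A minimal sketch of the `extract_id` helper assumed above: it recovers the
# Freebase topic ID from a tar member's file name. The naming scheme (ID as the
# file stem) is a guess for illustration; the real helper lives in this module:
import os
import tarfile

def extract_id_sketch(member: tarfile.TarInfo) -> str:
    stem, _ext = os.path.splitext(os.path.basename(member.name))
    return stem  # e.g. "wikifacts/m.02mjmr.sm" -> "m.02mjmr", assuming this layout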
def main():
    replace_canonical = any(arg.startswith('--replace') for arg in sys.argv)
    if not replace_canonical:
        # global SAVE_DIR
        # SAVE_DIR = Path('./data/wikifacts_orig/')
        pass
    else:
        print("Arguments: Will replace the canonical forms.")
    print(f"Output directory: {SAVE_DIR}")
    skip_embeddings = any(arg.startswith('--skip') for arg in sys.argv)
    if skip_embeddings:
        print("Arguments: Will skip embedding generation.")

    id_name, id_summary, id_relations, relation_types = load_wikifacts()
    entity_index, relation_index, entity_vec, relation_vec = load_openke_embeddings()

    # Check OpenKE coverage
    print(f"Entity coverage in OpenKE: "
          f"{sum(int(fid in entity_index) for fid in id_name)}/{len(id_name)}")
    print(f"Relation coverage in OpenKE: "
          f"{sum(int(fid in relation_index) for fid in relation_types)}/{len(relation_types)}")

    # Match entity positions and generate the pickled dataset.

    # Remap entities and relation types, and store the mapping
    entity_map: Dict[FreebaseID, int] = {}
    relation_map: Dict[FreebaseID, int] = {}
    mapped_entity_vecs: List[np.ndarray] = []
    mapped_relation_vecs: List[np.ndarray] = []

    # noinspection PyShadowingNames
    def get_relation_id(rel: FreebaseID) -> int:
        rel_id = relation_map.get(rel, None)
        if rel_id is None:
            rel_id = relation_map[rel] = len(relation_map)
            mapped_relation_vecs.append(extract_vector(relation_vec, relation_index[rel]))
        return rel_id

    # noinspection PyShadowingNames
    def get_entity_id(entity: FreebaseID) -> int:
        ent_id = entity_map.get(entity, None)
        if ent_id is None:
            if entity not in entity_index:  # not all entities are covered
                return UNK_ENTITY
            ent_id = entity_map[entity] = len(entity_map)
            mapped_entity_vecs.append(extract_vector(entity_vec, entity_index[entity]))
        return ent_id

    # Create the dataset
    dataset: List[Example] = []  # one data example per topic
    dataset_matched_spans: List[List[MatchedSpan]] = []  # matched spans per topic

    # noinspection PyShadowingNames
    def find_relation(rels: List[Tuple[FreebaseID, FreebaseID]],
                      obj: FreebaseID) -> List[FreebaseID]:
        # gather every relation whose object matches; callers take the first match
        matched_rels = []
        for r_rel, r_obj in rels:
            if r_obj == obj:
                matched_rels.append(r_rel)
        return matched_rels

    # noinspection PyShadowingNames
    def match_positions(tokens: List[str], name: List[str]) -> List[int]:
        matched = [False] * len(name)
        positions = []
        for idx, token in enumerate(tokens):
            for match_idx, match_token in enumerate(name):
                if matched[match_idx] or match_token != token:
                    continue
                positions.append(match_idx)
                matched[match_idx] = True
                break
            else:
                return []  # some token has no match
        return positions

    tokenizer = sacremoses.MosesTokenizer(lang='en')
    position_stats = defaultdict(int)
    for freebase_id in utils.progress(id_summary, desc='Creating dataset'):
        topic_id = get_entity_id(freebase_id)
        summary = id_summary[freebase_id].strip().split(' ')
        raw_relations = id_relations[freebase_id]
        relations = defaultdict(lambda: len(relations))
        rel_obj_names: Dict[int, str] = {}
        for r_rel, r_obj in raw_relations:
            rel_obj_names[relations[(get_relation_id(r_rel), get_entity_id(r_obj))]] = \
                id_name[r_obj]
        topic_name = (summary[0][2:summary[0].index('/')]
                      if '/' in summary[0] else "<unknown topic name>")
        rel_obj_names[relations[(TOPIC_ITSELF, topic_id)]] = topic_name  # the topic itself

        sentence, rel_ids, copy_pos, surface_indices = [], [], [], []
        matched_spans: List[MatchedSpan] = []

        def add_words(s: str):
            tokens = tokenizer.tokenize(s, escape=False)
            sentence.extend(tokens)
            rel_ids.extend([NAF] * len(tokens))
            copy_pos.extend([-1] * len(tokens))
            surface_indices.extend([-1] * len(tokens))

        for word in summary:
            if '@@' in word:
                start_pos = word.find('@@')
                end_pos = word.find('@@', start_pos + 2)
                if start_pos > 0:
                    add_words(word[:start_pos])  # leading text
                # there could be punctuation following the closing '@@'
                entity_desc = word[(start_pos + 2):end_pos].split('/')
                assert len(entity_desc) >= 4  # the entity name could contain '/'
                trailing = word[(end_pos + 2):]  # trailing text
                entity_name = '/'.join(entity_desc[:-3]).split('_')
                r_obj = FreebaseID(entity_desc[-1])
                obj_id = get_entity_id(r_obj)
                if entity_desc[-3] == 'f':
                    if r_obj == freebase_id:  # the topic itself
                        rel_id = TOPIC_ITSELF
                        matched_rels = [TOPIC_ITSELF]
                    else:
                        matched_rels = [get_relation_id(r_rel)
                                        for r_rel in find_relation(raw_relations, r_obj)]
                        # the relation might not exist
                        rel_id = None if len(matched_rels) == 0 else matched_rels[0]
                else:
                    rel_id = ANCHOR
                    # include anchors anyway and filter them out if we don't use them
                    matched_rels = [ANCHOR]
                    # matched_rels = []  # don't include anchors; we only use them for our model
                # if rel_id is None:  # not anymore, we now allow unknown entities
                if obj_id is None or rel_id is None:
                    # no embedding for the word; convert it to normal tokens
                    add_words(' '.join(entity_name) + trailing)
                else:
                    position_stats['all'] += 1
                    if replace_canonical and r_obj in id_name:
                        # basically everything except anchors: just replace `entity_name`,
                        # and positions will be in order
                        canonical_name = id_name[r_obj].split('_')
                        if canonical_name != entity_name:
                            position_stats['not_canonical'] += 1
                            entity_name = canonical_name
                    if entity_desc[-3] == 'f' and r_obj != freebase_id:
                        positions = match_positions(entity_name, id_name[r_obj].split('_'))
                        if len(positions) == 0:
                            # cannot match; resort to replacing with the canonical name
                            entity_name = id_name[r_obj].split('_')
                            positions = list(range(len(entity_name)))
                            position_stats['order'] += 1
                        else:
                            position_stats['match'] += 1
                            if positions == list(range(len(entity_name))):
                                position_stats['order'] += 1
                                position_stats['match_order'] += 1
                            elif positions == list(range(len(positions))):
                                position_stats['match_prefix'] += 1
                            elif positions == list(range(positions[0], positions[-1] + 1)):
                                position_stats['match_sub'] += 1
                    else:
                        # we don't replace canonical names for anchors (because we don't
                        # have them); these can't be used by our model anyway
                        if entity_desc[-3] == 'a':
                            position_stats['anchor'] += 1
                            # add the anchor relation
                            rel_obj_names[relations[(ANCHOR, obj_id)]] = ' '.join(entity_name)
                        # these may not be in `id_name`, so use them as-is
                        positions = list(range(len(entity_name)))
                        position_stats['order'] += 1
                    assert (rel_id, obj_id) in relations
                    rel_idx = relations[(rel_id, obj_id)]  # must exist
                    for rel_typ in matched_rels:
                        matched_spans.append((len(sentence), len(sentence) + len(entity_name),
                                              rel_typ, rel_idx, 0))
                    sentence.extend(entity_name)
                    rel_ids.extend([rel_idx] * len(entity_name))
                    copy_pos.extend(positions)
                    surface_indices.extend([0] * len(entity_name))
                    add_words(trailing)
            else:
                add_words(word)

        # we assume everything is canonical, so just add a pseudo-canonical form
        rel_rev_map = {idx: rel for rel, idx in relations.items()}
        rel_list = []  # just in case `list(relations)` doesn't do what we expect
        for idx in range(len(rel_rev_map)):
            rel_id, obj_id = rel_rev_map[idx]
            rel_list.append((rel_id, obj_id, [rel_obj_names[idx].replace('_', ' ')]))
        example = (sentence, topic_id, rel_list, rel_ids, copy_pos, surface_indices)
        dataset.append(example)
        dataset_matched_spans.append(matched_spans)

    print(f"Position stats: {position_stats}")
    if replace_canonical:
        assert position_stats['all'] == position_stats['order']
    # Save them
    directory: Path = SAVE_DATASET_PATH('train').parent
    if not directory.exists():
        directory.mkdir(parents=True)

    # noinspection PyShadowingNames
    def split_dataset(dataset: List[T], splits: List[Tuple[int, int]]) -> List[List[T]]:
        dataset_size = len(dataset)
        dataset_splits = []
        for l, r in splits:
            start = int(dataset_size * (l / 100))
            end = int(dataset_size * (r / 100))
            dataset_splits.append(dataset[start:end])
        return dataset_splits

    splits = [(0, 80), (80, 90), (90, 100)]
    dataset_splits = split_dataset(dataset, splits)
    matched_spans_splits = split_dataset(dataset_matched_spans, splits)
    print(f"{len(dataset)} examples")
    for split, data, spans in zip(['train', 'valid', 'test'],
                                  dataset_splits, matched_spans_splits):
        path = SAVE_DATASET_PATH(split)
        with path.open('wb') as f:
            pickle.dump(data, f)
        print(f"Dataset split '{split}' saved to {path}, {len(data)} examples")
        path = SAVE_MATCHED_SPANS_PATH(split)
        with path.open('wb') as f:
            pickle.dump(spans, f)
        print(f"Matched spans split '{split}' saved to {path}")

    # save relation type names for use during generation
    rel_names: Dict[int, str] = {NAF: 'Not-A-Fact', ANCHOR: 'ANCHOR', TOPIC_ITSELF: 'TITLE'}
    for r_rel, rel_id in relation_map.items():
        rel_names[rel_id] = r_rel
    with (SAVE_DIR / 'rel_names.pkl').open('wb') as f:
        pickle.dump(rel_names, f)
    print("Relation names saved.")

    if not skip_embeddings:
        with utils.work_in_progress("Saving entity vecs"):
            stacked_entity_vecs = torch.from_numpy(np.stack(mapped_entity_vecs))
            torch.save(stacked_entity_vecs, SAVE_ENTITY_VEC_PATH)
        print(f"Entity vecs saved to {SAVE_ENTITY_VEC_PATH}, "
              f"{len(stacked_entity_vecs)} vectors in total")
        with utils.work_in_progress("Saving relation vecs"):
            stacked_relation_vecs = torch.from_numpy(np.stack(mapped_relation_vecs))
            torch.save(stacked_relation_vecs, SAVE_RELATION_VEC_PATH)
        print(f"Relation vecs saved to {SAVE_RELATION_VEC_PATH}, "
              f"{len(stacked_relation_vecs)} vectors in total")
    else:
        print("Embedding updates skipped.")

    print("Processing done.")
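
# Standard entry point so the preprocessing can be run as a script, e.g.
# `python preprocess_wikifacts.py --replace` (the file name is illustrative;
# `main` reads its flags directly from `sys.argv` above):
if __name__ == '__main__':
    main()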