def _process_page(self, page_title: str) -> tuple[int, int, int, int]:
    """ Processes a Wikipedia article, appending the resulting sequences to self.examples.
    Returns the number of sequences, entities, word tokens, and words """
    sentences = self._get_sentence_features(page_title)

    # Construct features to be saved - word tokens, entities, and entity spans
    words = list()
    links: list[tuple[int, int, int]] = list()
    n_seqs, n_ents, n_word_toks, n_words = 0, 0, 0, 0
    TT.profile("Get features")
    for i, (sent_words, sent_links) in enumerate(sentences):
        links += [(id_, start + len(words), end + len(words)) for id_, start, end in sent_links]
        words += sent_words
        if i == len(sentences) - 1 or len(words) + len(sentences[i + 1][0]) > self.max_num_tokens:
            n_seqs += 1
            # Save features for this sequence
            links = links[:self.max_entities]
            n_ents += len(links)
            word_ids = self.tokenizer.convert_tokens_to_ids(words)
            with TT.profile("Word spans"):
                word_spans = calculate_spans(words, self.tokenizer)
            assert self.min_sentence_length <= len(word_ids) <= self.max_num_tokens
            entity_ids = [id_ for id_, _, _ in links]
            entity_spans = [(start + 1, end + 1) for _, start, end in links]
            n_word_toks += len(word_ids)
            n_words += len(word_spans)
            self.examples.append({
                "word_ids": [self.tokenizer.cls_token_id, *word_ids, self.tokenizer.sep_token_id],
                "word_spans": word_spans,
                "entity_ids": entity_ids,
                "entity_spans": entity_spans,
                "is_validation": False,  # Whether to mark doc. as part of validation set
            })
            words = list()
            links = list()
    TT.end_profile()
    return n_seqs, n_ents, n_word_toks, n_words
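# Illustration (not part of the pipeline): the loop above packs consecutive
# sentences greedily into one sequence until adding the next sentence would
# exceed max_num_tokens, then flushes. A minimal standalone sketch of the
# same packing rule, assuming sentences are plain token lists; sentences
# longer than max_num_tokens are already filtered out upstream.
def pack_sentences(sentences: list[list[str]], max_num_tokens: int) -> list[list[str]]:
    sequences, current = list(), list()
    for i, sent in enumerate(sentences):
        current += sent
        if i == len(sentences) - 1 or len(current) + len(sentences[i + 1]) > max_num_tokens:
            sequences.append(current)
            current = list()
    return sequences

# Example: with max_num_tokens=5, sentences of lengths 2, 2, and 3 are packed
# into a sequence of 4 tokens and a sequence of 3 tokens.
assert pack_sentences([["a", "b"], ["c", "d"], ["e", "f", "g"]], 5) \
    == [["a", "b", "c", "d"], ["e", "f", "g"]]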
@classmethod
def build(
    cls,
    examples: list[Example],
    device: torch.device,
    word_mask_id: int,
    ent_mask_id: int,
    word_mask_prob: float,
    word_unmask_prob: float,
    word_randword_prob: float,
    word_id_range: tuple[int, int],
    ent_mask_prob: float,
    cut_extra_padding: bool = True,
):
    with TT.profile("Combine to batch"):
        words, entities = cls.collate(examples, device=torch.device("cpu"), cut=cut_extra_padding)
    with TT.profile("Mask words"):
        word_mask_labels, word_mask = mask_word_batch(
            words, word_mask_prob, word_unmask_prob, word_randword_prob,
            word_id_range, word_mask_id,
        )
    with TT.profile("Mask entities"):
        ent_mask_labels, ent_mask = mask_ent_batch(entities, ent_mask_prob, ent_mask_id)
    with TT.profile("Send to %s" % device):
        words.ids = words.ids.to(device)
        words.attention_mask = words.attention_mask.to(device)
        words.N = words.N.to(device)
        entities.ids = entities.ids.to(device)
        entities.attention_mask = entities.attention_mask.to(device)
        entities.pos = entities.pos.to(device)
        entities.N = entities.N.to(device)
        word_mask_labels = word_mask_labels.to(device)
        word_mask = word_mask.to(device)
        ent_mask_labels = ent_mask_labels.to(device)
        ent_mask = ent_mask.to(device)
    return cls(words, entities, word_mask_labels, word_mask, ent_mask_labels, ent_mask)
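# Sketch (an assumption, not the library's actual mask_word_batch): the three
# probabilities implement BERT-style masking. Each token is selected for
# prediction with probability mask_prob; of the selected tokens, an
# unmask_prob fraction is left unchanged, a randword_prob fraction is
# replaced by a random id from id_range, and the remainder becomes mask_id.
import torch

def mask_ids(ids: torch.Tensor, mask_prob: float, unmask_prob: float,
             randword_prob: float, id_range: tuple[int, int], mask_id: int):
    selected = torch.rand_like(ids, dtype=torch.float) < mask_prob
    # -100 is ignored by nn.CrossEntropyLoss by default
    labels = torch.where(selected, ids, torch.full_like(ids, -100))
    roll = torch.rand_like(ids, dtype=torch.float)
    random_ids = torch.randint(*id_range, ids.shape)
    out = ids.clone()
    # roll < randword_prob: random word; next unmask_prob slice: unchanged; rest: mask
    out[selected & (roll < randword_prob)] = random_ids[selected & (roll < randword_prob)]
    out[selected & (roll >= randword_prob + unmask_prob)] = mask_id
    return out, labels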
def collate(self, batch: list[tuple[int, Example]]) -> MaskedBatchedExamples:
    with TT.profile("Build masked batch"):
        return MaskedBatchedExamples.build(
            [ex for _, ex in batch],
            self.device,
            word_mask_id=self.word_mask_id,
            ent_mask_id=self.ent_mask_id,
            word_mask_prob=self.word_mask_prob,
            word_unmask_prob=self.word_unmask_prob,
            word_randword_prob=self.word_randword_prob,
            word_id_range=self.random_word_id_range,
            ent_mask_prob=self.ent_mask_prob,
            cut_extra_padding=True,
        )
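# Usage sketch (assumed wiring; the dataset's own get_dataloader presumably
# does the equivalent): this collate method is passed as collate_fn to a
# torch DataLoader, so each batch of (index, Example) pairs is turned into a
# single MaskedBatchedExamples on the fly. `data` is assumed to be the
# dataset wrapper defined below; enumerate pairs each Example with its index
# to match collate's expected input.
from torch.utils.data import DataLoader as TorchDataLoader

loader = TorchDataLoader(
    list(enumerate(data.train_examples)),
    batch_size=16,
    collate_fn=data.collate,
)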
def __init__(
    self,
    data_dir: str,
    metadata: dict,
    entity_vocab: dict,
    device: torch.device,
    word_mask_prob: float,
    word_unmask_prob: float,
    word_randword_prob: float,
    ent_mask_prob: float,
    only_load_validation=False,
    vocab_size: int | None = None,
    token_map: np.ndarray | None = None,
    ent_min_mention: int | None = None,
):
    """ Loads a generated jsonl dataset prepared by the preprocessing pipeline """
    self.data_dir = data_dir
    self.metadata = metadata
    self.ent_ids = {info["id"] for info in entity_vocab.values()}
    self.ent_min_mention = ent_min_mention
    self.device = device
    self.max_sentence_len = metadata["max-seq-length"]
    self.max_entities = metadata["max-entities"]
    self.max_entity_span = metadata["max-entity-span"]

    self.word_mask_prob = word_mask_prob
    self.word_unmask_prob = word_unmask_prob
    self.word_randword_prob = word_randword_prob
    self.ent_mask_prob = ent_mask_prob
    self.only_load_validation = only_load_validation

    self.tokenizer = AutoTokenizer.from_pretrained(metadata["base-model"])
    self.sep_id, self.cls_id, self.pad_id, self.word_mask_id, __ = get_special_ids(self.tokenizer)
    if token_map is not None:
        self.sep_id, self.cls_id, self.pad_id, self.word_mask_id = token_map[
            [self.sep_id, self.cls_id, self.pad_id, self.word_mask_id]]
    self.ent_mask_id = entity_vocab[ENTITY_MASK_TOKEN]["id"]
    # Don't insert ids that are special tokens when performing random word insertion in the masking
    # The allowed range is dependent on the placement of special ids
    vocab_size = vocab_size or self.tokenizer.vocab_size
    self.random_word_id_range = (self.word_mask_id + 1, vocab_size) \
        if self.word_mask_id < vocab_size - 1 else \
        (self.tokenizer.convert_tokens_to_ids(self.tokenizer.unk_token) + 1, vocab_size - 1)

    with TT.profile("Build data"):
        self.train_examples, self.val_examples = self.build_examples()
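# Standalone restatement of the id-range branch above, with a worked example
# (the concrete ids are illustrative, not taken from any real tokenizer):
def random_word_id_range(word_mask_id: int, unk_id: int, vocab_size: int) -> tuple[int, int]:
    # Random-word replacement must never draw special-token ids. If [MASK]
    # sits among the low special ids, everything above it is fair game;
    # if [MASK] is the last id, start just after [UNK] and exclude [MASK].
    if word_mask_id < vocab_size - 1:
        return word_mask_id + 1, vocab_size
    return unk_id + 1, vocab_size - 1

# BERT-style layout: [MASK] = 4 among the low special ids, 32000 tokens total
assert random_word_id_range(word_mask_id=4, unk_id=1, vocab_size=32000) == (5, 32000)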
def build_examples(self) -> tuple[list[Example], list[Example]]:
    train_examples, val_examples = list(), list()
    with open(os.path.join(self.data_dir, DatasetBuilder.data_file)) as f, \
            TT.profile("Build example", hits=self.metadata["number-of-items"]):
        for seq_data in load_jsonl(f):
            is_validation = seq_data["is_validation"]
            if self.only_load_validation and not is_validation:
                continue
            if self.ent_min_mention:
                # Keep only entities in the filtered entity vocab
                seq_data["entity_spans"] = [
                    span for id_, span in zip(seq_data["entity_ids"], seq_data["entity_spans"])
                    if id_ in self.ent_ids
                ]
                seq_data["entity_ids"] = [id_ for id_ in seq_data["entity_ids"] if id_ in self.ent_ids]
            ex = Example(
                words=Words.build(
                    torch.IntTensor(seq_data["word_ids"]),
                    seq_data["word_spans"],
                    max_len=self.max_sentence_len,
                    pad_id=self.pad_id,
                ),
                entities=Entities.build(
                    torch.IntTensor(seq_data["entity_ids"]),
                    seq_data["entity_spans"],
                    max_entities=self.max_entities,
                    max_entity_span=self.max_entity_span,
                ),
            )
            if is_validation:
                val_examples.append(ex)
            else:
                train_examples.append(ex)
    return train_examples, val_examples
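# load_jsonl is assumed to yield one dict per non-empty line of the data
# file; a minimal equivalent under that assumption, matching how build()
# writes the file (one ujson-encoded example per line):
import ujson

def load_jsonl(f):
    for line in f:
        if line.strip():
            yield ujson.loads(line)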
def train(
    rank: int,
    world_size: int,
    *,
    resume: bool,
    location: str,
    name: str,
    quiet: bool,
    save_every: int,
    validate_every: int,
    post_command: str,
    explicit_args: set[str],
    params: Hyperparams,
):
    # Get filepath within path context
    fpath = lambda path: os.path.join(location, path) if isinstance(path, str) else os.path.join(location, *path)

    # Setup multi-gpu if used
    setup(rank, world_size)
    is_master = rank < 1  # Are we on the main node?
    is_distributed = rank != -1  # Are we performing distributed computing?
    num_workers = torch.distributed.get_world_size() if is_distributed else 1

    # Update locations
    TrainResults.subfolder = name
    Hyperparams.subfolder = name

    # Setup logger
    log.configure(
        os.path.join(location, name, "pretraining-worker=%s.log" % (rank if is_distributed else 0)),
        "DaLUKE pretraining on node %i" % rank,
        log_commit  = True,
        print_level = (Levels.INFO if quiet else Levels.DEBUG) if is_master else None,
        append      = resume,  # Append to existing log file if we are resuming training
    )

    post_time, post_command = parse_post_command(post_command)
    execute_post_command = False
    if post_time:
        log("Quitting in %.2f h and running command '%s'" % ((post_time - time.time()) / 3600, post_command))

    if resume:
        log("Resuming from %s" % name)
        # Load results and hyperparameters from earlier training
        res = TrainResults.load(location)
        # Close unended profiles
        close_tt(res.tt)
        TT.fuse(res.tt)
        res.tt = TT
        tmp_saved_pu = res.parameter_update
        loaded_params = Hyperparams.load(location)
        # Overwrite ff-size if given explicitly
        if "ff_size" in explicit_args:
            loaded_params.ff_size = params.ff_size
        params = loaded_params
    else:
        tmp_saved_pu = None
    log.section("Starting pretraining with the following hyperparameters", params)
    log("Training using %i workers" % num_workers)

    log("Reading metadata and entity vocabulary")
    with open(fpath(DatasetBuilder.metadata_file)) as f:
        metadata = json.load(f)
    with open(fpath(DatasetBuilder.entity_vocab_file)) as f:
        entity_vocab = json.load(f)
    log("Loaded metadata:", json.dumps(metadata, indent=4))
    log(f"Loaded entity vocabulary of {len(entity_vocab)} entities")
    if params.ent_min_mention:
        log("Removing entities with less than %i mentions" % params.ent_min_mention)
        entity_vocab = {
            ent: info for ent, info in entity_vocab.items()
            if info["count"] >= params.ent_min_mention or ent in {"[PAD]", "[UNK]", "[MASK]"}
        }
        log("After filtering, entity vocab now has %i entities" % len(entity_vocab))

    # Device should be cuda:rank or just cuda if single gpu, else cpu
    if is_distributed:
        device = torch.device("cuda", index=rank)
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    log.debug(
        "Hardware for this worker:",
        "CPU: %s" % cpuinfo.get_cpu_info()["brand_raw"],
        "GPU: %s" % (torch.cuda.get_device_name(device) if torch.cuda.is_available() else "NA"),
        sep="\t\n",
    )

    if params.entity_loss_weight:
        log("Setting up loss function with entity loss weighting")
        # Don't weigh special tokens
        weights = torch.Tensor([0, 0, 0, *(1 / info["count"] for info in entity_vocab.values() if info["count"])]).to(device)
        entity_criterion = nn.CrossEntropyLoss(weight=weights)
    else:
        log("Setting up loss function without entity loss weighting")
        entity_criterion = nn.CrossEntropyLoss()
    word_criterion = nn.CrossEntropyLoss()
    loss_calculator = lambda w, e: params.word_ent_weight * w + (1 - params.word_ent_weight) * e

    # Load dataset and training results
    bert_config = AutoConfig.from_pretrained(metadata["base-model"])
    if metadata["reduced-vocab"]:
        token_map_file = fpath(DatasetBuilder.token_map_file)
        log("Loading token map from '%s'" % token_map_file)
        token_map = np.load(token_map_file)
        tokenizer = AutoTokenizer.from_pretrained(metadata["base-model"])
        *__, unk_id = get_special_ids(tokenizer)
        token_reduction = token_map_to_token_reduction(token_map, unk_id)
    else:
        token_map = None

    log("Building dataset")
    data = DataLoader(
        location,
        metadata,
        entity_vocab,
        device,
        params.word_mask_prob,
        params.word_unmask_prob,
        params.word_randword_prob,
        params.ent_mask_prob,
        vocab_size=metadata["vocab-size"],
        token_map=token_map,
        ent_min_mention=params.ent_min_mention,
    )
    sampler = (DistributedSampler if is_distributed else RandomSampler)(data.train_examples)
    log("Built %i examples" % len(data))

    loader = data.get_dataloader(params.ff_size, sampler)
    val_loader = data.get_dataloader(params.ff_size, SequentialSampler(data.val_examples), validation=True)

    # Number of subbatches in each parameter update (batch)
    grad_accumulation_steps = params.batch_size // (params.ff_size * num_workers)
    # How many full batches can be made from the dataset
    batches_in_data = len(data) // params.batch_size
    log(
        "Parameter updates:               %i" % params.parameter_updates,
        "Subbatches per parameter update: %i" % grad_accumulation_steps,
        "Subbatches generated:            %i" % len(loader),
        "Batches needed to cover dataset: %i" % batches_in_data,
    )

    if not resume:
        # Calculate parameter differences, when at least 20k examples have been seen
        paramdiff_every = ceil(MIN_EXAMPLES_PER_PARAMDIFF / params.batch_size)
        log("Recalculating parameter differences every %i'th parameter update" % paramdiff_every)
        top_k = [1, 3, 10]
        log("Calculating top %s accuracies" % top_k)
        if validate_every:
            val_updates = unique(np.array(
                np.arange(-1, params.parameter_updates, validate_every).tolist() + [params.parameter_updates - 1]
            ))[1:]
        else:
            val_updates = np.array([], dtype=int)
        res = TrainResults(
            runtime           = np.zeros(params.parameter_updates),
            lr                = np.zeros(params.parameter_updates),
            parameter_update  = 0,
            losses            = np.zeros(params.parameter_updates),
            scaled_loss       = np.zeros(params.parameter_updates),
            top_k             = top_k,
            w_losses          = np.zeros(params.parameter_updates),
            e_losses          = np.zeros(params.parameter_updates),
            w_accuracies      = np.zeros((params.parameter_updates, len(top_k))),
            e_accuracies      = np.zeros((params.parameter_updates, len(top_k))),
            val_param_updates = val_updates,
            val_losses        = np.zeros(len(val_updates)),
            val_w_losses      = np.zeros(len(val_updates)),
            val_e_losses      = np.zeros(len(val_updates)),
            val_w_accuracies  = np.zeros((len(val_updates), len(top_k))),
            val_e_accuracies  = np.zeros((len(val_updates), len(top_k))),
            paramdiff_every   = paramdiff_every,
            groups_to_slices  = None,  # Set later
            orig_params       = None,
            paramdiff_1       = None,
            luke_exclusive_params = None,  # Set later
            att_mats_from_base    = None,  # Set later
            tt                = TT,
        )

    save_pus = set(range(-1, params.parameter_updates, save_every)).union({params.parameter_updates - 1})
    log(
        "Saving model at parameter updates: %s" % sorted(save_pus),
        "Validating at parameter updates: %s" % res.val_param_updates.tolist(),
    )

    # Build model, possibly by loading previous weights
    log.section("Setting up model")
    bert_config = AutoConfig.from_pretrained(metadata["base-model"])
    if params.ent_hidden_size is None:
        params.ent_hidden_size = bert_config.hidden_size
    else:
        assert params.ent_hidden_size <= bert_config.hidden_size, \
            "Entity hidden size (%i) cannot be larger than hidden size in '%s' (%i)" % (
                params.ent_hidden_size,
                metadata["base-model"],
                bert_config.hidden_size,
            )

    log("Initializing model")
    model_cls = BertAttentionPretrainTaskDaLUKE if params.bert_attention else PretrainTaskDaLUKE
    model = model_cls(
        bert_config,
        ent_vocab_size        = len(entity_vocab),
        ent_embed_size        = params.ent_embed_size,
        ent_hidden_size       = params.ent_hidden_size,
        ent_intermediate_size = params.ent_intermediate_size,
    ).to(device)
    bert_config.vocab_size = metadata["vocab-size"]
    log("Bert config", bert_config.to_json_string())

    if params.lukeinit:
        log("Initializing weights in accordance with LUKE")
        model.apply(lambda module: model.init_weights(module, bert_config.initializer_range))
    # Load parameters from base model
    if not params.no_base_model:
        log("Loading base model parameters")
        with TT.profile("Loading base model parameters"):
            base_model = AutoModelForPreTraining.from_pretrained(metadata["base-model"])
            new_weights = load_base_model_weights(
                model,
                base_model.state_dict(),
                params.bert_attention,
            )
        if metadata["reduced-vocab"]:
            log("Removing unneeded token weights")
            reduced_model = model_cls(
                bert_config,
                ent_vocab_size        = len(entity_vocab),
                ent_embed_size        = params.ent_embed_size,
                ent_hidden_size       = params.ent_hidden_size,
                ent_intermediate_size = params.ent_intermediate_size,
            ).to(device)
            copy_with_reduced_state_dict(token_reduction, model, reduced_model)
            model = reduced_model
    else:
        new_weights = set(model.state_dict())
    # Initialize self-attention query matrices to BERT word query matrices
    att_mat_keys = set()
    if not params.bert_attention and not params.no_base_model:
        log("Initializing new attention matrices with%s PCA" % ("" if params.pcainit else "out"))
        att_mat_keys = model.init_special_attention(params.pcainit, device)
    if not resume:
        res.luke_exclusive_params = new_weights
        res.att_mats_from_base = att_mat_keys
        if is_master:
            res.orig_params = all_params(model).cpu().numpy()
    log("Pretraining model initialized with %s parameters" % thousand_seps(len(model)))

    # Unfixes params at this parameter update
    unfix_base_model_params_pu = round(params.bert_fix_prop * params.parameter_updates)
    log("Unfixing base model params after %i parameter updates" % unfix_base_model_params_pu)

    if resume:
        mpath = fpath((TrainResults.subfolder, MODEL_OUT.format(i=res.parameter_update)))
        log("Loading model from '%s'" % mpath)
        model.load_state_dict(torch.load(mpath, map_location=device))
        log(f"Resuming training saved at parameter update {res.parameter_update}")
    else:
        res.groups_to_slices, t = all_params_groups_to_slices(model, bert_config.num_hidden_layers)
        log("Parameter groups and positions", t)
        res.paramdiff_1 = {
            name: np.zeros(ceil(params.parameter_updates / res.paramdiff_every))
            for name in res.groups_to_slices
        }
    if is_distributed:
        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
    non_ddp_model = model.module if is_distributed else model

    log("Setting up optimizer, scaler, and learning rate scheduler")
    optimizer = get_optimizer(non_ddp_model, params.weight_decay, params.lr)
    scaler = amp.GradScaler() if params.fp16 else None
    scheduler = get_lr_scheduler(
        optimizer,
        int(params.warmup_prop * params.parameter_updates),
        params.parameter_updates,
        unfix_base_model_params_pu,
    )
    if resume:
        optimizer.load_state_dict(torch.load(fpath((TrainResults.subfolder, OPTIMIZER_OUT.format(i=res.parameter_update))), map_location=device))
        scheduler.load_state_dict(torch.load(fpath((TrainResults.subfolder, SCHEDULER_OUT.format(i=res.parameter_update))), map_location=device))
        if params.fp16:
            scaler.load_state_dict(torch.load(fpath((TrainResults.subfolder, SCALER_OUT.format(i=res.parameter_update))), map_location=device))
        res.parameter_update += 1  # We saved the data at pu i, but should now commence pu i+1

    log.debug("Time distribution before starting training", TT)
    log_memory_stats(device)
    log.section(f"Training DaLUKE for {params.parameter_updates} parameter updates")
    model.zero_grad()  # To avoid tracking of model parameter manipulation
    model.train()

    # Start with transfer learned weights locked
    fix_base_model_params(res.luke_exclusive_params, non_ddp_model, True)
    fixed_params = True

    # Save initial parameters
    if is_master and not resume:
        with TT.profile("Saving progress"):
            paths = save_training(location, params, model.module if is_distributed else model, res, optimizer, scheduler, scaler, -1)
            log.debug("Saved initial state to", *paths)

    batch_iter = iter(loader)
    for i in range(res.parameter_update, params.parameter_updates):
        TT.profile("Parameter update")
        res.parameter_update = i
        if i >= unfix_base_model_params_pu and fixed_params:
            log("Unfixing base model params")
            fix_base_model_params(res.luke_exclusive_params, model, False)
            fixed_params = False
        if is_distributed and i % batches_in_data == 0:
            sampler.set_epoch(i // batches_in_data)

        # Losses and accuracies for this parameter update
        t_loss, w_loss, e_loss, s_loss = 0, 0, 0, 0
        w_accuracies = np.zeros((grad_accumulation_steps, len(res.top_k)))
        e_accuracies = np.zeros((grad_accumulation_steps, len(res.top_k)))

        # Loop over enough batches to make a parameter update
        for j in range(grad_accumulation_steps):
            TT.profile("Sub-batch")
            try:
                batch = next(batch_iter)
            except StopIteration:
                batch_iter = iter(loader)
                batch = next(batch_iter)

            TT.profile("FP and gradients")
            with amp.autocast() if params.fp16 else contextlib.ExitStack():
                word_preds, ent_preds = model(batch)
                # Compute and backpropagate loss
                word_loss = word_criterion(word_preds, batch.word_mask_labels)
                ent_loss = entity_criterion(ent_preds, batch.ent_mask_labels)
                has_entities = not torch.isnan(ent_loss).item()
                ent_loss = torch.nan_to_num(ent_loss)
                loss = loss_calculator(word_loss, ent_loss)
                loss /= grad_accumulation_steps

            # Only sync parameters on grad updates, aka last pass of this loop
            with model.no_sync() if is_distributed and j < grad_accumulation_steps - 1 else contextlib.ExitStack():
                if params.fp16:
                    scaled_loss = scaler.scale(loss)
                    scaled_loss.backward()
                    s_loss += scaled_loss.item()
                else:
                    loss.backward()
            t_loss += loss.item()
            w_loss += word_loss.item() / grad_accumulation_steps
            e_loss += ent_loss.item() / grad_accumulation_steps if has_entities else 0

            if torch.cuda.is_available():
                torch.cuda.synchronize(rank if is_distributed else None)
            TT.end_profile()

            # Save accuracy for statistics
            if is_master:
                with TT.profile("Training accuracy"):
                    w_accuracies[j] = top_k_accuracy(batch.word_mask_labels, word_preds, res.top_k)
                    e_accuracies[j] = top_k_accuracy(batch.ent_mask_labels, ent_preds, res.top_k)
            TT.end_profile()

        # Update model parameters
        with TT.profile("Parameter step"):
            if params.fp16:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()
            scheduler.step()
            model.zero_grad()

        # Calculate how much gradient has changed
        if is_master and i % res.paramdiff_every == 0:
            with torch.no_grad(), TT.profile("Parameter changes"):
                log.debug("Calculating parameter changes")
                orig_pars = torch.from_numpy(res.orig_params).to(device)
                current_pars = all_params(model.module if is_distributed else model)
                absdiff = torch.abs(current_pars - orig_pars)
                for blockname, slice_ in res.groups_to_slices.items():
                    j = i // res.paramdiff_every
                    res.paramdiff_1[blockname][j] = absdiff[slice_].sum().item()
                del orig_pars, current_pars

        res.losses[i] = t_loss
        res.w_losses[i] = w_loss
        res.e_losses[i] = e_loss
        res.scaled_loss[i] = s_loss
        res.lr[i] = scheduler.get_last_lr()[0]
        res.w_accuracies[i] = np.mean(w_accuracies, axis=0)
        res.e_accuracies[i] = np.nanmean(e_accuracies, axis=0)
        res.runtime[i] = TT.end_profile()
        log.debug(
            "Performed parameter update %i / %i in %.2f s" % (i, params.parameter_updates - 1, res.runtime[i]),
            f"  Loss (total, word, entity, scaled): {t_loss:9.4f}, {w_loss:9.4f}, {e_loss:9.4f}, {s_loss:.4f}",
            f"  Accuracy (word, entity):            {100*res.w_accuracies[i, 0]:7.2f} %, {100*res.e_accuracies[i, 0]:7.2f} %",
        )

        if i in res.val_param_updates and is_master:
            TT.profile("Model validation")
            log("Validating model")
            vi = res.val_param_updates.tolist().index(i)
            res.val_w_losses[vi], res.val_e_losses[vi], res.val_w_accuracies[vi], res.val_e_accuracies[vi] = \
                validate_model(model, val_loader, word_criterion, entity_criterion, res.top_k)
            res.val_losses[vi] = loss_calculator(res.val_w_losses[vi], res.val_e_losses[vi])
            log(
                "Validation loss:",
                "  Total:  %9.4f" % res.val_losses[vi],
                "  Word:   %9.4f" % res.val_w_losses[vi],
                "  Entity: %9.4f" % res.val_e_losses[vi],
                "Validation accuracy:",
                "  Word:   %7.2f %%" % (100 * res.val_w_accuracies[vi, 0]),
                "  Entity: %7.2f %%" % (100 * res.val_e_accuracies[vi, 0]),
            )
            model.train()
            TT.end_profile()
            log.debug("Time distribution so far", TT)

        # Save results and model
        if is_master and i in save_pus:
            with TT.profile("Saving progress"):
                save_progress(location, i, tmp_saved_pu, save_pus, params, model.module if is_distributed else model, res, optimizer, scheduler, scaler)
        if i in save_pus:
            log_memory_stats(device)

        # If timed out, save, quit, and run resume command
        if post_time and time.time() > post_time:
            log_memory_stats(device)
            log.section("Time limit reached. Quitting and running command '%s'" % post_command)
            with TT.profile("Saving progress"):
                save_progress(location, i, tmp_saved_pu, save_pus, params, model.module if is_distributed else model, res, optimizer, scheduler, scaler)
            execute_post_command = True
            break

    log.debug("Time distribution", TT)

    # Clean up multi-gpu if used
    cleanup(rank)
    if is_master and execute_post_command:
        os.system(post_command)
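# Worked example of the gradient-accumulation arithmetic used above
# (the concrete numbers are illustrative): a parameter update covering
# batch_size=2048 examples, processed in subbatches of ff_size=64 across
# num_workers=4 GPUs, needs 2048 // (64 * 4) = 8 forward/backward passes per
# optimizer step. Dividing each subbatch loss by grad_accumulation_steps
# keeps the accumulated gradient equal to that of a single full-batch pass.
batch_size, ff_size, num_workers = 2048, 64, 4
grad_accumulation_steps = batch_size // (ff_size * num_workers)
assert grad_accumulation_steps == 8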
class DatasetBuilder:

    tokenizer_language = "da"

    # Files saved by the build method
    metadata_file     = "metadata.json"
    entity_vocab_file = "entity-vocab.json"
    data_file         = "data.jsonl"
    token_map_file    = "token-map.npy"

    def __init__(
        self,
        dump_db_file:        str,         # Location of file built by build-dump-db
        tokenizer_name:      str,         # Tokenizer to use, e.g. Maltehb/danish-bert-botxo for Danish BERT
        entity_vocab_file:   str,         # Built by build-entity-vocab
        out_dir:             str,         # Where to put finished dataset. All contents will be removed before saving dataset
        validation_prob:     float,       # Chance of each finished document to be marked as part of validation set
        max_entities:        int,         # Only up to this many entities are included in each sequence
        max_entity_span:     int,         # Maximum number of tokens an entity can span before sequence is discarded
        min_sentence_length: int,         # Minimum number of tokens a sentence must span to be included
        max_articles:        int | None,
        max_vocab_size:      int,
    ):
        if not wikipedia2vec_available:
            raise ModuleNotFoundError(
                "Pretrain data generation requires installation of the optional requirement `wikipedia2vec`"
            )
        log("Reading dump database at %s" % dump_db_file)
        self.dump_db = DumpDB(dump_db_file)
        log("Building tokenizer: %s" % tokenizer_name)
        self.tokenizer_name = tokenizer_name
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        log("Building sentence tokenizer: %s" % self.tokenizer_language)
        self.sentence_tokenizer = ICUSentenceTokenizer(self.tokenizer_language)
        log("Loading entity vocab at %s" % entity_vocab_file)
        self.entity_vocab = load_entity_vocab(entity_vocab_file)
        # Make sure IDs on non-ignored entities are contiguous
        num = 0
        for entity_info in self.entity_vocab.values():
            entity_info["id"] = num
            num += 1
        log("Entity vocab has size %i" % num)

        self.out_dir = out_dir
        self.data_file = os.path.join(self.out_dir, self.data_file)
        self.token_map_file = os.path.join(self.out_dir, self.token_map_file)
        self.max_seq_length = self.tokenizer.model_max_length
        self.validation_prob = validation_prob
        self.max_entities = max_entities
        self.max_entity_span = max_entity_span
        self.min_sentence_length = min_sentence_length
        # Get maximum number of tokens in a sequence excluding start and end tokens
        self.max_num_tokens = self.max_seq_length - 2
        self.max_articles = max_articles
        self.vocab_size = self.tokenizer.vocab_size if max_vocab_size == -1 \
            else min(max_vocab_size, self.tokenizer.vocab_size)

        # Filter titles so only real articles are included
        self.target_titles = list(self.dump_db.titles())

        # Remove old datafile if it exists
        if os.path.isfile(self.data_file):
            log.debug("Removing old datafile '%s'" % self.data_file)
            os.remove(self.data_file)

        self.examples = list()

    def _tokenize(self, text: str, paragraph_text: str, idx: int) -> list[str]:
        if not text:
            return list()
        try:
            if isinstance(self.tokenizer, RobertaTokenizer):
                tokens = self.tokenizer.tokenize(
                    text,
                    add_prefix_space=idx == 0 or text.startswith(" ") or paragraph_text[idx - 1] == " ",
                )
            else:
                tokens = self.tokenizer.tokenize(text)
        except KeyboardInterrupt:
            # Make sure the program can be keyboard interrupted despite needing to catch BaseException
            raise
        except BaseException as e:
            # Catch an exception caused by rust panicking in the tokenizer
            log.warning("Failed to tokenize text with exception '%s'\nText: '%s'" % (e, text))
            return list()
        return tokens

    def build(self):
        log("Saving tokenizer config and word token config to '%s'" % self.out_dir)
        with open(path := os.path.join(self.out_dir, self.entity_vocab_file), "w", encoding="utf-8") as ev:
            log("Saving entity vocab to '%s'" % path)
            ujson.dump(self.entity_vocab, ev, indent=2)

        log.section("Processing %i pages" % len(self.target_titles[:self.max_articles]))
        n_seqs, n_ents, n_word_toks, n_words = 0, 0, 0, 0
        for title in log.tqdm(tqdm(self.target_titles[:self.max_articles])):
            log("Processing %s" % title)
            with TT.profile("Process page"):
                s, e, nt, nw = self._process_page(title)
                n_seqs += s
                n_ents += e
                n_word_toks += nt
                n_words += nw

        log("Shuffling data")
        random.shuffle(self.examples)
        n_vals = int(self.validation_prob * len(self.examples))
        for i in range(n_vals):
            self.examples[i]["is_validation"] = True

        # Save metadata
        metadata = {
            "number-of-items":       n_seqs,
            "number-of-word-tokens": n_word_toks,
            "number-of-words":       n_words,
            "number-of-entities":    n_ents,
            "number-of-val-items":   n_vals,
            "max-seq-length":        self.max_seq_length,
            "max-entities":          self.max_entities,
            "max-entity-span":       self.max_entity_span,
            "min-sentence-length":   self.min_sentence_length,
            "base-model":            self.tokenizer_name,
            "tokenizer-class":       self.tokenizer.__class__.__name__,
            "language":              self.dump_db.language,
            "reduced-vocab":         self.vocab_size < self.tokenizer.vocab_size,
            "vocab-size":            self.vocab_size,
        }
        if self.vocab_size < self.tokenizer.vocab_size:
            log.section("Reducing token number")
            with TT.profile("Reduce token vocab"):
                token_map, metadata["vocab-size"] = self._reduce_tokens()
            with TT.profile("Rewrite dataset with new tokens"):
                self._update_tokens(token_map)

        with open(path := os.path.join(self.out_dir, self.metadata_file), "w") as f:
            log.section("Saving metadata to %s" % path)
            ujson.dump(metadata, f, indent=4)

        with open(self.data_file, "w") as f, TT.profile("Save data"):
            log("Saving data to '%s'" % self.data_file)
            for example in self.examples:
                f.write(ujson.dumps(example) + "\n")

        log.debug("Time distribution", TT)
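# Usage sketch for the builder above (the argument values are illustrative
# placeholders, not taken from the source):
builder = DatasetBuilder(
    dump_db_file="da_wiki.db",
    tokenizer_name="Maltehb/danish-bert-botxo",
    entity_vocab_file="entity-vocab.json",
    out_dir="datasets/da",
    validation_prob=0.01,
    max_entities=128,
    max_entity_span=30,
    min_sentence_length=5,
    max_articles=None,   # None processes every article in the dump
    max_vocab_size=-1,   # -1 keeps the full tokenizer vocab
)
builder.build()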
def _get_sentence_features(self, page_title: str) -> list[tuple[list[str], list[tuple[int, int, int]]]]:
    sentences = list()

    # Process by paragraph
    for paragraph in self.dump_db.get_paragraphs(page_title):
        paragraph_links: list[tuple[str, int, int]] = list()
        paragraph_text = paragraph.text

        # Get paragraph links
        # These are represented by three-tuples consisting of their title, start and end string positions
        TT.profile("Get links")
        for link in paragraph.wiki_links:
            link_title: str = self.dump_db.resolve_redirect(link.title)
            # Remove links to articles that are not included
            if ignore_title(link_title):
                paragraph_text = paragraph_text[:link.start] \
                    + " " * (link.end - link.start) \
                    + paragraph_text[link.end:]
            elif link_title in self.entity_vocab:
                paragraph_links.append((link_title, link.start, link.end))
        paragraph_links = list(reversed(paragraph_links))
        TT.end_profile()

        # Process by sentence
        TT.profile("Sentences")
        if paragraph_links:
            link_title, link_start, link_end = paragraph_links.pop()
        else:
            link_title, link_start, link_end = "", -1, -1
        sent_spans = self.sentence_tokenizer.span_tokenize(paragraph_text.rstrip())
        for sent_start, sent_end in sent_spans:
            current = sent_start
            sent_words = list()  # Tokens in the given sentence
            sent_links = list()  # Links in a given sentence in three-tuples: (id, start index, end index)
            too_large_tokens = False
            while link_start < sent_start:
                try:
                    link_title, link_start, link_end = paragraph_links.pop()
                except IndexError:
                    break
            while sent_start <= link_start and link_end <= sent_end:
                # Look for links that are within the tokenized sentence
                # If a link is found, the sentence is separated across the link and tokenized
                TT.profile("Tokenize")
                text = paragraph_text[current:link_start]
                sent_words += self._tokenize(text, paragraph_text, current)
                link_text = paragraph_text[link_start:link_end]
                link_words = self._tokenize(link_text, paragraph_text, link_start)
                TT.end_profile()

                sent_links.append((
                    self.entity_vocab[link_title]["id"],
                    len(sent_words),
                    len(sent_words) + len(link_words),
                ))
                if sent_links[-1][2] - sent_links[-1][1] > self.max_entity_span:
                    too_large_tokens = True
                    break
                sent_words += link_words
                current = link_end
                try:
                    link_title, link_start, link_end = paragraph_links.pop()
                except IndexError:
                    break
            text = paragraph_text[current:sent_end]
            sent_words += self._tokenize(text, paragraph_text, current)
            if len(sent_words) >= self.min_sentence_length \
                    and len(sent_words) <= self.max_num_tokens \
                    and not too_large_tokens:
                sentences.append((sent_words, sent_links))
        TT.end_profile()
    return sentences
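# Illustration (hypothetical tokenization and entity id): how a wiki link
# becomes an (entity id, token start, token end) triple relative to the
# sentence's word-piece tokens, with the end index exclusive.
sent_words = ["Dan", "##mark", "er", "et", "land", "."]  # assumed word-piece output
link_words = sent_words[:2]                              # the link "Danmark" spans the first two tokens
entity_id  = 7                                           # assumed id in the entity vocab
sent_link  = (entity_id, 0, len(link_words))
assert sent_link == (7, 0, 2)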
"reduced-vocab": self.vocab_size < self.tokenizer.vocab_size, "vocab-size": self.vocab_size, } if self.vocab_size < self.tokenizer.vocab_size: log.section("Reducing token number") with TT.profile("Reduce token vocab"): token_map, metadata["vocab-size"] = self._reduce_tokens() with TT.profile("Rewrite dataset with new tokens"): self._update_tokens(token_map) with open(path := os.path.join(self.out_dir, self.metadata_file), "w") as f: log.section("Saving metadata to %s" % path) ujson.dump(metadata, f, indent=4) with open(self.data_file, "w") as f, TT.profile("Save data"): log("Saving data to '%s'" % self.data_file) for example in self.examples: f.write(ujson.dumps(example) + "\n") log.debug("Time distribution", TT) def _get_sentence_features(self, page_title: str) -> list[tuple[list[str], 3]]: sentences = list() # Process by paragraph for paragraph in self.dump_db.get_paragraphs(page_title): paragraph_links: list[tuple[str, int, int]] = list() paragraph_text = paragraph.text
def evaluate_ner(
    model: nn.Module,
    dataloader: torch.utils.data.DataLoader,
    dataset: NERDataset,
    device: torch.device,
    split: Split,
    also_no_misc=True,
) -> NER_Results:
    model.eval()
    annotations, texts = dataset.data[split].annotations, dataset.data[split].texts
    span_probs: list[dict[tuple[int, int], np.ndarray]] = list(dict() for _ in range(len(texts)))

    log.debug(f"Forward passing {len(dataloader)} batches")
    TT.tick()
    for batch in tqdm(dataloader):
        scores = model(batch)
        probs = F.softmax(scores, dim=2)
        # We save the probability distribution for every possible span in the example
        for idx, (i, spans) in zip(batch.text_nums, enumerate(batch.entities.fullword_spans)):
            span_probs[idx].update({
                span: probs[i, j].detach().cpu().numpy()
                for j, span in enumerate(spans) if span
            })
    preds = [span_probs_to_preds(p, len(t), dataset) for p, t in zip(span_probs, texts)]
    log(f"Forward pass completed: Wall time: {TT.tock():.4f} s.")

    stats = _stats_to_py_nums(classification_report(annotations, preds, output_dict=True, zero_division=0))
    log(classification_report(annotations, preds, zero_division=0, digits=4))

    confmat = confusion_matrix(annotations, preds, dataset.all_labels)
    confmat_nomisc = dict()
    log("Prediction distribution", _format_confmat(confmat))

    if also_no_misc:
        # FIXME: Do this manually instead of rerunning everything
        stats_nomisc = _stats_to_py_nums(
            classification_report(
                _rm_misc(annotations, dataset.null_label),
                _rm_misc(preds, dataset.null_label),
                output_dict=True,
            )
        )
        log(classification_report(
            _rm_misc(annotations, dataset.null_label),
            _rm_misc(preds, dataset.null_label),
            digits=4,
        ))
        confmat_nomisc = confusion_matrix(
            _rm_misc(annotations, dataset.null_label),
            _rm_misc(preds, dataset.null_label),
            dataset.all_labels,
        )
        log("Prediction distribution", _format_confmat(confmat_nomisc))

    return NER_Results(
        preds=preds,
        span_probs=span_probs,
        statistics=stats,
        statistics_nomisc=stats_nomisc if also_no_misc else {},
        confusion_matrix=confmat,
        confusion_matrix_nomisc=confmat_nomisc,
    )
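# Sketch (an assumption about what span_probs_to_preds does, not its actual
# source): a simple greedy decoder from per-span class probabilities to IOB2
# tags could pick the argmax class per span, discard null-label spans, then
# assign spans to tokens in order of confidence, skipping overlaps.
import numpy as np

def greedy_decode(span_probs: dict[tuple[int, int], np.ndarray],
                  seq_len: int, labels: list[str], null_label: str) -> list[str]:
    tags = [null_label] * seq_len
    # Most confident spans first
    scored = sorted(
        ((probs.max(), int(probs.argmax()), span) for span, probs in span_probs.items()),
        reverse=True,
    )
    for _, label_idx, (start, end) in scored:
        label = labels[label_idx]
        if label == null_label or any(t != null_label for t in tags[start:end]):
            continue  # Skip null predictions and overlaps with already-assigned spans
        tags[start:end] = ["B-" + label] + ["I-" + label] * (end - start - 1)
    return tags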