Example #1
    def _process_page(self, page_title: str) -> tuple[int, int, int, int]:
        """ Processes a Wikipedia article and save to self.data_file
        Returns number of sequences """

        sentences = self._get_sentence_features(page_title)

        # Construct features to be saved - word tokens, entities, and entity spans
        words = list()
        links: list[tuple[int, int, int]] = list()
        n_seqs, n_ents, n_word_toks, n_words = 0, 0, 0, 0
        TT.profile("Get features")
        for i, (sent_words, sent_links) in enumerate(sentences):
            links += [(id_, start + len(words), end + len(words))
                      for id_, start, end in sent_links]
            words += sent_words
            if i == len(sentences) - 1 or len(words) + len(sentences[i + 1][0]) > self.max_num_tokens:
                n_seqs += 1
                # Save features for this sequence
                links = links[:self.max_entities]
                n_ents += len(links)
                word_ids = self.tokenizer.convert_tokens_to_ids(words)
                with TT.profile("Word spans"):
                    word_spans = calculate_spans(words, self.tokenizer)
                assert self.min_sentence_length <= len(word_ids) <= self.max_num_tokens
                entity_ids = [id_ for id_, _, _ in links]
                entity_spans = [(start + 1, end + 1)
                                for _, start, end in links]
                # is_validation is overwritten later in build() when sampling the validation set
                n_word_toks += len(word_ids)
                n_words += len(word_spans)
                self.examples.append({
                    "word_ids": [self.tokenizer.cls_token_id, *word_ids, self.tokenizer.sep_token_id],
                    "word_spans": word_spans,
                    "entity_ids": entity_ids,
                    "entity_spans": entity_spans,
                    "is_validation": False,
                })
                words = list()
                links = list()
        TT.end_profile()

        return n_seqs, n_ents, n_word_toks, n_words
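
calculate_spans is not shown in these examples. A minimal sketch of what it is assumed to do: group subword tokens into per-word (start, end) index spans. The simplified signature (the real call also passes the tokenizer) and the WordPiece "##" continuation convention are assumptions.

def calculate_spans_sketch(tokens: list[str]) -> list[tuple[int, int]]:
    """Group subword tokens into per-word (start, end) index spans."""
    spans = []
    for i, token in enumerate(tokens):
        if token.startswith("##") and spans:
            spans[-1] = (spans[-1][0], i + 1)  # continuation: extend the current word
        else:
            spans.append((i, i + 1))  # a new word starts here
    return spans

# calculate_spans_sketch(["dan", "##sk", "tekst"]) == [(0, 2), (2, 3)]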
Example #2
    @classmethod
    def build(
        cls,
        examples: list[Example],
        device: torch.device,
        word_mask_id: int,
        ent_mask_id: int,
        word_mask_prob: float,
        word_unmask_prob: float,
        word_randword_prob: float,
        word_id_range: tuple[int, int],
        ent_mask_prob: float,
        cut_extra_padding: bool = True,
    ):
        with TT.profile("Combine to batch"):
            words, entities = cls.collate(examples, device=torch.device("cpu"), cut=cut_extra_padding)
        with TT.profile("Mask words"):
            word_mask_labels, word_mask = mask_word_batch(
                words, word_mask_prob, word_unmask_prob, word_randword_prob,
                word_id_range, word_mask_id)
        with TT.profile("Mask entities"):
            ent_mask_labels, ent_mask = mask_ent_batch(entities, ent_mask_prob, ent_mask_id)
        with TT.profile("Send to %s" % device):
            words.ids = words.ids.to(device)
            words.attention_mask = words.attention_mask.to(device)
            words.N = words.N.to(device)
            entities.ids = entities.ids.to(device)
            entities.attention_mask = entities.attention_mask.to(device)
            entities.pos = entities.pos.to(device)
            entities.N = entities.N.to(device)
            word_mask_labels = word_mask_labels.to(device)
            word_mask = word_mask.to(device)
            ent_mask_labels = ent_mask_labels.to(device)
            ent_mask = ent_mask.to(device)
        return cls(words, entities, word_mask_labels, word_mask, ent_mask_labels, ent_mask)
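
mask_word_batch and mask_ent_batch are not shown. A self-contained sketch of the BERT-style scheme their arguments suggest: select positions with probability mask_prob; of the selected, leave a fraction unchanged, replace a fraction with random ids, and mask the rest. Operating on a raw id tensor is a simplification; the real functions work on the Words and Entities containers.

import torch

def mask_ids_sketch(ids: torch.Tensor, mask_prob: float, unmask_prob: float,
                    randword_prob: float, id_range: tuple[int, int], mask_id: int):
    """Mask ids in place, BERT-style, and return (labels, mask)."""
    mask = torch.rand(ids.shape) < mask_prob                      # positions to predict
    labels = torch.where(mask, ids, torch.full_like(ids, -100)).long()  # -100: ignored by CrossEntropyLoss
    decide = torch.rand(ids.shape)
    keep = mask & (decide < unmask_prob)                          # selected, but id left unchanged
    rand = mask & ~keep & (decide < unmask_prob + randword_prob)  # selected, replaced by a random id
    ids[mask & ~keep & ~rand] = mask_id
    ids[rand] = torch.randint(*id_range, ids.shape, dtype=ids.dtype)[rand]
    return labels, mask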
Example #3
    def collate(self, batch: list[tuple[int, Example]]) -> MaskedBatchedExamples:
        with TT.profile("Build masked batch"):
            return MaskedBatchedExamples.build(
                [ex for _, ex in batch],
                self.device,
                word_mask_id=self.word_mask_id,
                ent_mask_id=self.ent_mask_id,
                word_mask_prob=self.word_mask_prob,
                word_unmask_prob=self.word_unmask_prob,
                word_randword_prob=self.word_randword_prob,
                word_id_range=self.random_word_id_range,
                ent_mask_prob=self.ent_mask_prob,
                cut_extra_padding=True,
            )
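
For context, collate has the shape of a torch DataLoader collate_fn over (index, Example) pairs. A plausible wiring sketch; the dataset construction is an assumption, and data stands for an instance of the DataLoader class defined in the next example:

from torch.utils.data import DataLoader as TorchDataLoader

dataset = list(enumerate(data.train_examples))  # items match list[tuple[int, Example]]
loader = TorchDataLoader(dataset, batch_size=32, collate_fn=data.collate)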
Example #4
    def __init__(
        self,
        data_dir: str,
        metadata: dict,
        entity_vocab: dict,
        device: torch.device,
        word_mask_prob: float,
        word_unmask_prob: float,
        word_randword_prob: float,
        ent_mask_prob: float,
        only_load_validation: bool = False,
        vocab_size: int | None = None,
        token_map: np.ndarray | None = None,
        ent_min_mention: int | None = None,
    ):
        """ Loads a generated json dataset prepared by the preprocessing pipeline """
        self.data_dir = data_dir
        self.metadata = metadata
        self.ent_ids = {info["id"] for info in entity_vocab.values()}
        self.ent_min_mention = ent_min_mention
        self.device = device

        self.max_sentence_len = metadata["max-seq-length"]
        self.max_entities = metadata["max-entities"]
        self.max_entity_span = metadata["max-entity-span"]

        self.word_mask_prob = word_mask_prob
        self.word_unmask_prob = word_unmask_prob
        self.word_randword_prob = word_randword_prob
        self.ent_mask_prob = ent_mask_prob
        self.only_load_validation = only_load_validation

        self.tokenizer = AutoTokenizer.from_pretrained(metadata["base-model"])
        self.sep_id, self.cls_id, self.pad_id, self.word_mask_id, __ = get_special_ids(
            self.tokenizer)
        if token_map is not None:
            self.sep_id, self.cls_id, self.pad_id, self.word_mask_id = token_map[
                [self.sep_id, self.cls_id, self.pad_id, self.word_mask_id]]
        self.ent_mask_id = entity_vocab[ENTITY_MASK_TOKEN]["id"]
        # Don't insert special token ids when sampling random words during masking
        # The allowed range depends on the placement of the special ids
        vocab_size = vocab_size or self.tokenizer.vocab_size
        self.random_word_id_range = (
            (self.word_mask_id + 1, vocab_size)
            if self.word_mask_id < vocab_size - 1
            else (self.tokenizer.convert_tokens_to_ids(self.tokenizer.unk_token) + 1, vocab_size - 1)
        )
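        # Illustration (assumed ids): with word_mask_id = 4 and vocab_size = 32000,
        # the range becomes (5, 32000), so the special ids 0-4 are never drawn.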

        with TT.profile("Build data"):
            self.train_examples, self.val_examples = self.build_examples()
Example #5
    def build_examples(self) -> tuple[list[Example], list[Example]]:
        train_examples, val_examples = list(), list()
        with open(os.path.join(self.data_dir, DatasetBuilder.data_file)) as f,\
            TT.profile("Build example", hits=self.metadata["number-of-items"]):
            for seq_data in load_jsonl(f):
                is_validation = seq_data["is_validation"]
                if self.only_load_validation and not is_validation:
                    continue
                if self.ent_min_mention:
                    # Keep only entities in filtered entity vocab
                    seq_data["entity_spans"] = [
                        span for id_, span in zip(seq_data["entity_ids"],
                                                  seq_data["entity_spans"])
                        if id_ in self.ent_ids
                    ]
                    seq_data["entity_ids"] = [
                        id_ for id_ in seq_data["entity_ids"]
                        if id_ in self.ent_ids
                    ]

                ex = Example(
                    words=Words.build(
                        torch.IntTensor(seq_data["word_ids"]),
                        seq_data["word_spans"],
                        max_len=self.max_sentence_len,
                        pad_id=self.pad_id,
                    ),
                    entities=Entities.build(
                        torch.IntTensor(seq_data["entity_ids"]),
                        seq_data["entity_spans"],
                        max_entities=self.max_entities,
                        max_entity_span=self.max_entity_span,
                    ),
                )
                if is_validation:
                    val_examples.append(ex)
                else:
                    train_examples.append(ex)

        return train_examples, val_examples
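
load_jsonl is not defined in these examples. A minimal sketch consistent with its use above (streaming one JSON object per line of a .jsonl file); the actual implementation may differ:

import ujson

def load_jsonl(f):
    """Yield one parsed JSON object per non-empty line."""
    for line in f:
        if line.strip():
            yield ujson.loads(line)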
Example #6
def train(
    rank:           int,
    world_size:     int,
    *,
    resume:         bool,
    location:       str,
    name:           str,
    quiet:          bool,
    save_every:     int,
    validate_every: int,
    post_command:   str,
    explicit_args:  set[str],
    params:         Hyperparams,
):
    # Get filepath within path context
    fpath = lambda path: os.path.join(location, path) if isinstance(path, str) else os.path.join(location, *path)

    # Setup multi-gpu if used
    setup(rank, world_size)

    is_master = rank < 1  # Are we on the main node?
    is_distributed = rank != -1  # Are we performing distributed computing?
    num_workers = torch.distributed.get_world_size() if is_distributed else 1

    # Update locations
    TrainResults.subfolder = name
    Hyperparams.subfolder = name

    # Setup logger
    log.configure(
        os.path.join(location, name, "pretraining-worker=%s.log" % (rank if is_distributed else 0)),
        "DaLUKE pretraining on node %i" % rank,
        log_commit  = True,
        print_level = (Levels.INFO if quiet else Levels.DEBUG) if is_master else None,
        append      = resume,  # Append to existing log file if we are resuming training
    )

    post_time, post_command = parse_post_command(post_command)
    execute_post_command = False
    if post_time:
        log("Quitting in %.2f h and running command '%s'" % ((post_time-time.time())/3600, post_command))

    if resume:
        log("Resuming from %s" % name)
        # Load results and hyperparameters from earlier training
        res = TrainResults.load(location)
        # Close unended profiles
        close_tt(res.tt)
        TT.fuse(res.tt)
        res.tt = TT
        tmp_saved_pu = res.parameter_update
        loaded_params = Hyperparams.load(location)
        # Overwrite ff-size if given explicitly
        if "ff_size" in explicit_args:
            loaded_params.ff_size = params.ff_size
        params = loaded_params
    else:
        tmp_saved_pu = None
    log.section("Starting pretraining with the following hyperparameters", params)
    log("Training using %i workers" % num_workers)

    log("Reading metadata and entity vocabulary")
    with open(fpath(DatasetBuilder.metadata_file)) as f:
        metadata = json.load(f)
    with open(fpath(DatasetBuilder.entity_vocab_file)) as f:
        entity_vocab = json.load(f)
    log("Loaded metadata:", json.dumps(metadata, indent=4))
    log(f"Loaded entity vocabulary of {len(entity_vocab)} entities")
    if params.ent_min_mention:
        log("Removing entities with less than %i mentions" % params.ent_min_mention)
        entity_vocab = { ent: info for ent, info in entity_vocab.items()
            if info["count"] >= params.ent_min_mention or ent in {"[PAD]", "[UNK]", "[MASK]"} }
        log("After filtering, entity vocab now has %i entities" % len(entity_vocab))

    # Device should be cuda:rank or just cuda if single gpu, else cpu
    if is_distributed:
        device = torch.device("cuda", index=rank)
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    log.debug(
        "Hardware for this worker:",
        "CPU: %s" % cpuinfo.get_cpu_info()["brand_raw"],
        "GPU: %s" % (torch.cuda.get_device_name(device) if torch.cuda.is_available() else "NA"),
        sep="\t\n",
    )

    if params.entity_loss_weight:
        log("Setting up loss function with entity loss weighting")
        # Don't weigh special tokens
        weights = torch.Tensor([0, 0, 0, *(1 / info["count"] for info in entity_vocab.values() if info["count"])]).to(device)
        entity_criterion = nn.CrossEntropyLoss(weight=weights)
    else:
        log("Setting up loss function without entity loss weighting")
        entity_criterion = nn.CrossEntropyLoss()
    word_criterion = nn.CrossEntropyLoss()
    loss_calculator = lambda w, e: params.word_ent_weight * w + (1 - params.word_ent_weight) * e
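    # Illustration (assumed value): word_ent_weight = 0.5 gives the plain average of the
    # two losses, while 1.0 would ignore the entity loss entirely.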

    # Load dataset and training results
    bert_config = AutoConfig.from_pretrained(metadata["base-model"])
    if metadata["reduced-vocab"]:
        token_map_file = fpath(DatasetBuilder.token_map_file)
        log("Loading token map from '%s'" % token_map_file)
        token_map = np.load(token_map_file)
        tokenizer = AutoTokenizer.from_pretrained(metadata["base-model"])
        *__, unk_id = get_special_ids(tokenizer)
        token_reduction = token_map_to_token_reduction(token_map, unk_id)
    else:
        token_map = None

    log("Building dataset")
    data = DataLoader(
        location,
        metadata,
        entity_vocab,
        device,
        params.word_mask_prob,
        params.word_unmask_prob,
        params.word_randword_prob,
        params.ent_mask_prob,
        vocab_size=metadata["vocab-size"],
        token_map=token_map,
        ent_min_mention=params.ent_min_mention,
    )
    sampler = (DistributedSampler if is_distributed else RandomSampler)(data.train_examples)
    log("Built %i examples" % len(data))

    loader = data.get_dataloader(params.ff_size, sampler)
    val_loader = data.get_dataloader(params.ff_size, SequentialSampler(data.val_examples), validation=True)

    # Number of subbatches in each parameter update (batch)
    grad_accumulation_steps = params.batch_size // (params.ff_size * num_workers)
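    # Example (assumed values): batch_size = 2048, ff_size = 32 and 4 workers give
    # 2048 // (32 * 4) = 16 subbatches accumulated per parameter update.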
    # How many full batches can be made from the dataset
    batches_in_data = len(data) // params.batch_size
    log(
        "Parameter updates:               %i" % params.parameter_updates,
        "Subbatches per parameter update: %i" % grad_accumulation_steps,
        "Subbatches generated:            %i" % len(loader),
        "Batches needed to cover dataset: %i" % batches_in_data,
    )

    if not resume:
        # Calculate parameter differences whenever at least 20k examples have been seen
        paramdiff_every = ceil(MIN_EXAMPLES_PER_PARAMDIFF / params.batch_size)
        log("Recalculating parameter differences every %i'th parameter update" % paramdiff_every)
        top_k = [1, 3, 10]
        log("Calculating top %s accuracies" % top_k)
        if validate_every:
            val_updates = unique(np.array(
                np.arange(-1, params.parameter_updates, validate_every).tolist() + [params.parameter_updates-1]
            ))[1:]
        else:
            val_updates = np.array([], dtype=int)
        res = TrainResults(
            runtime           = np.zeros(params.parameter_updates),
            lr                = np.zeros(params.parameter_updates),
            parameter_update  = 0,

            losses            = np.zeros(params.parameter_updates),
            scaled_loss       = np.zeros(params.parameter_updates),

            top_k             = top_k,
            w_losses          = np.zeros(params.parameter_updates),
            e_losses          = np.zeros(params.parameter_updates),
            w_accuracies      = np.zeros((params.parameter_updates, len(top_k))),
            e_accuracies      = np.zeros((params.parameter_updates, len(top_k))),

            val_param_updates = val_updates,
            val_losses        = np.zeros(len(val_updates)),
            val_w_losses      = np.zeros(len(val_updates)),
            val_e_losses      = np.zeros(len(val_updates)),
            val_w_accuracies  = np.zeros((len(val_updates), len(top_k))),
            val_e_accuracies  = np.zeros((len(val_updates), len(top_k))),

            paramdiff_every   = paramdiff_every,
            groups_to_slices  = None,  # Set later
            orig_params       = None,
            paramdiff_1       = None,

            luke_exclusive_params = None,  # Set later
            att_mats_from_base    = None,  # Set later

            tt = TT,
        )

    save_pus = set(range(-1, params.parameter_updates, save_every)).union({params.parameter_updates-1})
    log("Saving model at parameter updates: %s" % sorted(save_pus),
        "Validating at parameter updates: %s" % res.val_param_updates.tolist())

    # Build model, possibly by loading previous weights
    log.section("Setting up model")
    bert_config = AutoConfig.from_pretrained(metadata["base-model"])
    if params.ent_hidden_size is None:
        params.ent_hidden_size = bert_config.hidden_size
    else:
        assert params.ent_hidden_size <= bert_config.hidden_size,\
            "Entity hidden size (%i) cannot be larger than hidden size in '%s' (%i)" % (
                params.ent_hidden_size,
                metadata["base-model"],
                bert_config.hidden_size,
            )

    log("Initializing model")
    model_cls = BertAttentionPretrainTaskDaLUKE if params.bert_attention else PretrainTaskDaLUKE
    model = model_cls(
        bert_config,
        ent_vocab_size        = len(entity_vocab),
        ent_embed_size        = params.ent_embed_size,
        ent_hidden_size       = params.ent_hidden_size,
        ent_intermediate_size = params.ent_intermediate_size,
    ).to(device)
    bert_config.vocab_size = metadata["vocab-size"]
    log("Bert config", bert_config.to_json_string())

    if params.lukeinit:
        log("Initializing weights in accordance with LUKE")
        model.apply(lambda module: model.init_weights(module, bert_config.initializer_range))
    # Load parameters from base model
    if not params.no_base_model:
        log("Loading base model parameters")
        with TT.profile("Loading base model parameters"):
            base_model = AutoModelForPreTraining.from_pretrained(metadata["base-model"])
            new_weights = load_base_model_weights(
                model,
                base_model.state_dict(),
                params.bert_attention,
            )
            if metadata["reduced-vocab"]:
                log("Removing unneeded token weights")
                reduced_model = model_cls(
                    bert_config,
                    ent_vocab_size        = len(entity_vocab),
                    ent_embed_size        = params.ent_embed_size,
                    ent_hidden_size       = params.ent_hidden_size,
                    ent_intermediate_size = params.ent_intermediate_size,
                ).to(device)
                copy_with_reduced_state_dict(token_reduction, model, reduced_model)
                model = reduced_model
    else:
        new_weights = set(model.state_dict())
    # Initialize self-attention query matrices to BERT word query matrices
    att_mat_keys = set()
    if not params.bert_attention and not params.no_base_model:
        log("Initializing new attention matrices with%s PCA" % ("" if params.pcainit else "out"))
        att_mat_keys = model.init_special_attention(params.pcainit, device)
    if not resume:
        res.luke_exclusive_params = new_weights
        res.att_mats_from_base = att_mat_keys
        if is_master:
            res.orig_params = all_params(model).cpu().numpy()
    log("Pretraining model initialized with %s parameters" % thousand_seps(len(model)))

    # Parameter update from which the base model params are unfixed
    unfix_base_model_params_pu = round(params.bert_fix_prop * params.parameter_updates)
    log("Unfixing base model params after %i parameter updates" % unfix_base_model_params_pu)

    if resume:
        mpath = fpath((TrainResults.subfolder, MODEL_OUT.format(i=res.parameter_update)))
        log("Loading model from '%s'" % mpath)
        model.load_state_dict(torch.load(mpath, map_location=device))
        log(f"Resuming training saved at parameter update {res.parameter_update}")
    else:
        res.groups_to_slices, t = all_params_groups_to_slices(model, bert_config.num_hidden_layers)
        log("Parameter groups and positions", t)
        res.paramdiff_1 = { name: np.zeros(ceil(params.parameter_updates/res.paramdiff_every)) for name in res.groups_to_slices }
    if is_distributed:
        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
    non_ddp_model = model.module if is_distributed else model

    log("Setting up optimizer, scaler, and learning rate scheduler")
    optimizer = get_optimizer(non_ddp_model, params.weight_decay, params.lr)
    scaler = amp.GradScaler() if params.fp16 else None
    scheduler = get_lr_scheduler(
        optimizer,
        int(params.warmup_prop * params.parameter_updates),
        params.parameter_updates,
        unfix_base_model_params_pu,
    )
    if resume:
        optimizer.load_state_dict(torch.load(fpath((TrainResults.subfolder, OPTIMIZER_OUT.format(i=res.parameter_update))), map_location=device))
        scheduler.load_state_dict(torch.load(fpath((TrainResults.subfolder, SCHEDULER_OUT.format(i=res.parameter_update))), map_location=device))
        if params.fp16:
            scaler.load_state_dict(torch.load(fpath((TrainResults.subfolder, SCALER_OUT.format(i=res.parameter_update))), map_location=device))
        res.parameter_update += 1  # We saved the data at pu i, but should now commence pu i+1

    log.debug("Time distribution before starting training", TT)
    log_memory_stats(device)

    log.section(f"Training DaLUKE for {params.parameter_updates} parameter updates")
    model.zero_grad()  # To avoid tracking of model parameter manipulation
    model.train()

    # Start with transfer learned weights locked
    fix_base_model_params(res.luke_exclusive_params, non_ddp_model, True)
    fixed_params = True

    # Save initial parameters
    if is_master and not resume:
        with TT.profile("Saving progress"):
            paths = save_training(location, params, model.module if is_distributed else model,
                res, optimizer, scheduler, scaler, -1)
            log.debug("Saved initial state to", *paths)

    batch_iter = iter(loader)
    for i in range(res.parameter_update, params.parameter_updates):
        TT.profile("Parameter update")
        res.parameter_update = i
        if i >= unfix_base_model_params_pu and fixed_params:
            log("Unfixing base model params")
            fix_base_model_params(res.luke_exclusive_params, model, False)
            fixed_params = False
        if is_distributed and i % batches_in_data == 0:
            sampler.set_epoch(i // batches_in_data)

        # Losses and accuracies for this parameter update
        t_loss, w_loss, e_loss, s_loss = 0, 0, 0, 0
        w_accuracies = np.zeros((grad_accumulation_steps, len(res.top_k)))
        e_accuracies = np.zeros((grad_accumulation_steps, len(res.top_k)))

        # Loop over enough batches to make a parameter update
        for j in range(grad_accumulation_steps):
            TT.profile("Sub-batch")
            try:
                batch = next(batch_iter)
            except StopIteration:
                batch_iter = iter(loader)
                batch = next(batch_iter)

            TT.profile("FP and gradients")
            with amp.autocast() if params.fp16 else contextlib.ExitStack():
                word_preds, ent_preds = model(batch)
                # Compute and backpropagate loss
                word_loss = word_criterion(word_preds, batch.word_mask_labels)
                ent_loss = entity_criterion(ent_preds, batch.ent_mask_labels)
                has_entities = not torch.isnan(ent_loss).item()
                ent_loss = torch.nan_to_num(ent_loss)
            loss = loss_calculator(word_loss, ent_loss)
            loss /= grad_accumulation_steps

            # Only sync parameters on grad updates, aka last pass of this loop
            with model.no_sync() if is_distributed and j < grad_accumulation_steps - 1 else contextlib.ExitStack():
                if params.fp16:
                    scaled_loss = scaler.scale(loss)
                    scaled_loss.backward()
                    s_loss += scaled_loss.item()
                else:
                    loss.backward()

            t_loss += loss.item()
            w_loss += word_loss.item() / grad_accumulation_steps
            e_loss += ent_loss.item() / grad_accumulation_steps if has_entities else 0

            if torch.cuda.is_available():
                torch.cuda.synchronize(rank if is_distributed else None)

            TT.end_profile()

            # Save accuracy for statistics
            if is_master:
                with TT.profile("Training accuracy"):
                    w_accuracies[j] = top_k_accuracy(batch.word_mask_labels, word_preds, res.top_k)
                    e_accuracies[j] = top_k_accuracy(batch.ent_mask_labels, ent_preds, res.top_k)

            TT.end_profile()

        # Update model parameters
        with TT.profile("Parameter step"):
            if params.fp16:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()
            scheduler.step()
            model.zero_grad()

        # Calculate how much the parameters have changed
        if is_master and i % res.paramdiff_every == 0:
            with torch.no_grad(), TT.profile("Parameter changes"):
                log.debug("Calculating parameter changes")
                orig_pars = torch.from_numpy(res.orig_params).to(device)
                current_pars = all_params(model.module if is_distributed else model)
                absdiff = torch.abs(current_pars-orig_pars)
                for blockname, slice_ in res.groups_to_slices.items():
                    j = i // res.paramdiff_every
                    res.paramdiff_1[blockname][j] = absdiff[slice_].sum().item()
                del orig_pars, current_pars

        res.losses[i]       = t_loss
        res.w_losses[i]     = w_loss
        res.e_losses[i]     = e_loss
        res.scaled_loss[i]  = s_loss
        res.lr[i]           = scheduler.get_last_lr()[0]
        res.w_accuracies[i] = np.mean(w_accuracies, axis=0)
        res.e_accuracies[i] = np.nanmean(e_accuracies, axis=0)
        res.runtime[i]      = TT.end_profile()
        log.debug(
            "Performed parameter update %i / %i in %.2f s" % (i, params.parameter_updates-1, res.runtime[i]),
            f"  Loss (total, word, entity, scaled): {t_loss:9.4f}, {w_loss:9.4f}, {e_loss:9.4f}, {s_loss:.4f}",
            f"  Accuracy (word, entity): {100*res.w_accuracies[i, 0]:7.2f} %, {100*res.e_accuracies[i, 0]:7.2f} %",
        )

        if i in res.val_param_updates and is_master:
            TT.profile("Model validation")
            log("Validating model")
            vi = res.val_param_updates.tolist().index(i)
            res.val_w_losses[vi], res.val_e_losses[vi], res.val_w_accuracies[vi], res.val_e_accuracies[vi] =\
                validate_model(model, val_loader, word_criterion, entity_criterion, res.top_k)
            res.val_losses[vi] = loss_calculator(res.val_w_losses[vi], res.val_e_losses[vi])
            log(
                "Validation loss:",
                "  Total:  %9.4f" % res.val_losses[vi],
                "  Word:   %9.4f" % res.val_w_losses[vi],
                "  Entity: %9.4f" % res.val_e_losses[vi],
                "Validation accuracy:",
                "  Word:   %7.2f %%" % (100 * res.val_w_accuracies[vi, 0]),
                "  Entity: %7.2f %%" % (100 * res.val_e_accuracies[vi, 0]),
            )
            model.train()
            TT.end_profile()
            log.debug("Time distribution so far", TT)

        # Save results and model
        if is_master and i in save_pus:
            with TT.profile("Saving progress"):
                save_progress(location, i, tmp_saved_pu, save_pus, params,
                    model.module if is_distributed else model, res, optimizer, scheduler, scaler)
        if i in save_pus:
            log_memory_stats(device)

        # If timed out, save, quit, and run resume command
        if post_time and time.time() > post_time:
            log_memory_stats(device)
            log.section("Time limit reached. Quitting and running command '%s'" % post_command)
            with TT.profile("Saving progress"):
                save_progress(location, i, tmp_saved_pu, save_pus, params,
                    model.module if is_distributed else model, res, optimizer, scheduler, scaler)
            execute_post_command = True
            break

    log.debug("Time distribution", TT)

    # Clean up multi-gpu if used
    cleanup(rank)

    if is_master and execute_post_command:
        os.system(post_command)
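
setup and cleanup are not shown. A plausible sketch of the torch.distributed initialization they presumably wrap; the backend choice, the address defaults, and the rank == -1 convention for single-process runs are assumptions based on how rank is used above:

import os
import torch.distributed as dist

def setup(rank: int, world_size: int):
    """Initialize the default process group when running distributed."""
    if rank != -1:
        os.environ.setdefault("MASTER_ADDR", "localhost")
        os.environ.setdefault("MASTER_PORT", "29500")
        dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup(rank: int):
    """Tear down the process group after training."""
    if rank != -1:
        dist.destroy_process_group()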
Example #7
class DatasetBuilder:

    tokenizer_language = "da"

    # Files saved by the build method
    metadata_file = "metadata.json"
    entity_vocab_file = "entity-vocab.json"
    data_file = "data.jsonl"
    token_map_file = "token-map.npy"

    def __init__(
        self,
        dump_db_file: str,  # Location of the file built by build-dump-db
        tokenizer_name: str,  # Tokenizer to use, e.g. Maltehb/danish-bert-botxo for Danish BERT
        entity_vocab_file: str,  # Built by build-entity-vocab
        out_dir: str,  # Where to put the finished dataset. All contents are removed before saving
        validation_prob: float,  # Chance of each finished document to be marked as part of the validation set
        max_entities: int,  # Only up to this many entities are included in each sequence
        max_entity_span: int,  # Maximum number of tokens an entity may span before the sequence is discarded
        min_sentence_length: int,  # Minimum number of tokens a sentence must span to be included
        max_articles: int | None,
        max_vocab_size: int,
    ):
        if not wikipedia2vec_available:
            raise ModuleNotFoundError(
                "Pretrain data generation requires installation of the optional requirement `wikipedia2vec`"
            )
        log("Reading dump database at %s" % dump_db_file)
        self.dump_db = DumpDB(dump_db_file)
        log("Building tokeninizer: %s" % tokenizer_name)
        self.tokenizer_name = tokenizer_name
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        log("Building sentence tokenizer: %s" % self.tokenizer_language)
        self.sentence_tokenizer = ICUSentenceTokenizer(self.tokenizer_language)
        log("Loading entity vocab at %s" % entity_vocab_file)
        self.entity_vocab = load_entity_vocab(entity_vocab_file)
        # Make sure IDs on non-ignored entities are contiguous
        num = 0
        for entity_info in self.entity_vocab.values():
            entity_info["id"] = num
            num += 1
        log("Entity vocab has size %i" % num)

        self.out_dir = out_dir
        self.data_file = os.path.join(self.out_dir, self.data_file)
        self.token_map_file = os.path.join(self.out_dir, self.token_map_file)
        self.max_seq_length = self.tokenizer.model_max_length
        self.validation_prob = validation_prob
        self.max_entities = max_entities
        self.max_entity_span = max_entity_span
        self.min_sentence_length = min_sentence_length
        # Get maximum number of tokens in a sequence excluding start and end tokens
        self.max_num_tokens = self.max_seq_length - 2
        self.max_articles = max_articles
        self.vocab_size = (self.tokenizer.vocab_size if max_vocab_size == -1
                           else min(max_vocab_size, self.tokenizer.vocab_size))

        # Filter titles so only real articles are included
        self.target_titles = list(self.dump_db.titles())

        # Remove old datafile if it exists
        if os.path.isfile(self.data_file):
            log.debug("Removing old datafile '%s'" % self.data_file)
            os.remove(self.data_file)

        self.examples = list()

    def _tokenize(self, text: str, paragraph_text: str, idx: int) -> list[str]:
        if not text:
            return list()
        try:
            if isinstance(self.tokenizer, RobertaTokenizer):
                tokens = self.tokenizer.tokenize(
                    text,
                    add_prefix_space=(idx == 0 or text.startswith(" ")
                                      or paragraph_text[idx - 1] == " "),
                )
            else:
                tokens = self.tokenizer.tokenize(text)
        except KeyboardInterrupt:
            # Make sure program can be keyboard interrupted despite needing to catch BaseException
            raise
        except BaseException as e:
            # Catch an exception caused by rust panicking in the tokenizer
            log.warning(
                "Failed to tokenize text with exception '%s'\nText: '%s'" %
                (e, text))
            return list()

        return tokens

    def build(self):
        log("Saving tokenizer config and word token config to '%s'" %
            self.out_dir)
        with open(path := os.path.join(self.out_dir, self.entity_vocab_file),
                  "w",
                  encoding="utf-8") as ev:
            log("Saving entity vocab to '%s'" % path)
            ujson.dump(self.entity_vocab, ev, indent=2)

        log.section("Processing %i pages" %
                    len(self.target_titles[:self.max_articles]))
        n_seqs, n_ents, n_word_toks, n_words = 0, 0, 0, 0
        for title in log.tqdm(tqdm(self.target_titles[:self.max_articles])):
            log("Processing %s" % title)
            with TT.profile("Process page"):
                s, e, nt, nw = self._process_page(title)
                n_seqs += s
                n_ents += e
                n_word_toks += nt
                n_words += nw

        log("Shuffling data")
        random.shuffle(self.examples)
        n_vals = int(self.validation_prob * len(self.examples))
        for i in range(n_vals):
            self.examples[i]["is_validation"] = True

        # Save metadata
        metadata = {
            "number-of-items": n_seqs,
            "number-of-word-tokens": n_word_toks,
            "number-of-words": n_words,
            "number-of-entities": n_ents,
            "number-of-val-items": n_vals,
            "max-seq-length": self.max_seq_length,
            "max-entities": self.max_entities,
            "max-entity-span": self.max_entity_span,
            "min-sentence-length": self.min_sentence_length,
            "base-model": self.tokenizer_name,
            "tokenizer-class": self.tokenizer.__class__.__name__,
            "language": self.dump_db.language,
            "reduced-vocab": self.vocab_size < self.tokenizer.vocab_size,
            "vocab-size": self.vocab_size,
        }

        if self.vocab_size < self.tokenizer.vocab_size:
            log.section("Reducing token number")
            with TT.profile("Reduce token vocab"):
                token_map, metadata["vocab-size"] = self._reduce_tokens()
            with TT.profile("Rewrite dataset with new tokens"):
                self._update_tokens(token_map)

        with open(path := os.path.join(self.out_dir, self.metadata_file),
                  "w") as f:
            log.section("Saving metadata to %s" % path)
            ujson.dump(metadata, f, indent=4)
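
_reduce_tokens and _update_tokens are not shown. A minimal sketch of the assumed idea behind the token map: assign contiguous new ids to the token ids that actually occur in the saved examples, so embeddings for unused tokens can be dropped. The function name, signature, and the collapse of unused ids onto [UNK] are assumptions:

import numpy as np

def reduce_tokens_sketch(examples: list[dict], full_vocab_size: int, unk_id: int):
    """Build an old-id -> new-id map covering only the ids seen in the data."""
    used = np.zeros(full_vocab_size, dtype=bool)
    used[unk_id] = True  # always keep the unknown token
    for example in examples:
        used[example["word_ids"]] = True
    new_ids = np.cumsum(used) - 1  # contiguous new ids for the kept tokens
    token_map = np.where(used, new_ids, new_ids[unk_id])  # unused ids collapse to [UNK]
    return token_map, int(used.sum())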
Example #8
    def _get_sentence_features(self, page_title: str) -> list[tuple[list[str], list[tuple[int, int, int]]]]:

        sentences = list()

        # Process by paragraph
        for paragraph in self.dump_db.get_paragraphs(page_title):
            paragraph_links: list[tuple[str, int, int]] = list()
            paragraph_text = paragraph.text

            # Get paragraph links
            # These are represented by three-tuples of title, start, and end string positions
            TT.profile("Get links")
            for link in paragraph.wiki_links:
                link_title: str = self.dump_db.resolve_redirect(link.title)
                # Remove links to articles that are not included
                if ignore_title(link_title):
                    paragraph_text = paragraph_text[:link.start]\
                        + " " * (link.end - link.start)\
                        + paragraph_text[link.end:]
                elif link_title in self.entity_vocab:
                    paragraph_links.append((link_title, link.start, link.end))
            paragraph_links = list(reversed(paragraph_links))
            TT.end_profile()

            # Process by sentence
            TT.profile("Sentences")
            if paragraph_links:
                link_title, link_start, link_end = paragraph_links.pop()
            else:
                link_title, link_start, link_end = "", -1, -1

            sent_spans = self.sentence_tokenizer.span_tokenize(
                paragraph_text.rstrip())
            for sent_start, sent_end in sent_spans:
                current = sent_start
                sent_words = list()  # Tokens in the given sentence
                # Links in the sentence as three-tuples: (id, start index, end index)
                sent_links = list()
                too_large_tokens = False

                while link_start < sent_start:
                    try:
                        link_title, link_start, link_end = paragraph_links.pop()
                    except IndexError:
                        break

                while sent_start <= link_start and link_end <= sent_end:
                    # Look for links that fall within the tokenized sentence
                    # If a link is found, the sentence is separated across the link and tokenized
                    TT.profile("Tokenize")
                    text = paragraph_text[current:link_start]
                    sent_words += self._tokenize(text, paragraph_text, current)

                    link_text = paragraph_text[link_start:link_end]
                    link_words = self._tokenize(link_text, paragraph_text,
                                                link_start)
                    TT.end_profile()

                    sent_links.append((
                        self.entity_vocab[link_title]["id"],
                        len(sent_words),
                        len(sent_words) + len(link_words),
                    ))
                    if sent_links[-1][2] - sent_links[-1][1] > self.max_entity_span:
                        too_large_tokens = True
                        break
                    sent_words += link_words
                    current = link_end

                    try:
                        link_title, link_start, link_end = paragraph_links.pop()
                    except IndexError:
                        break

                text = paragraph_text[current:sent_end]
                sent_words += self._tokenize(text, paragraph_text, current)

                if (self.min_sentence_length <= len(sent_words) <= self.max_num_tokens
                        and not too_large_tokens):
                    sentences.append((sent_words, sent_links))
            TT.end_profile()

        return sentences
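
For intuition, a toy illustration of the (sent_words, sent_links) pairs this method returns; the tokens and the entity id are made up:

sent_words = ["Danmark", "er", "et", "land"]
sent_links = [(42, 0, 1)]  # entity 42 spans word tokens [0, 1), i.e. "Danmark"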
Example #9
            "reduced-vocab": self.vocab_size < self.tokenizer.vocab_size,
            "vocab-size": self.vocab_size,
        }

        if self.vocab_size < self.tokenizer.vocab_size:
            log.section("Reducing token number")
            with TT.profile("Reduce token vocab"):
                token_map, metadata["vocab-size"] = self._reduce_tokens()
            with TT.profile("Rewrite dataset with new tokens"):
                self._update_tokens(token_map)

        with open(path := os.path.join(self.out_dir, self.metadata_file),
                  "w") as f:
            log.section("Saving metadata to %s" % path)
            ujson.dump(metadata, f, indent=4)
        with open(self.data_file, "w") as f, TT.profile("Save data"):
            log("Saving data to '%s'" % self.data_file)
            for example in self.examples:
                f.write(ujson.dumps(example) + "\n")

        log.debug("Time distribution", TT)

    def _get_sentence_features(self, page_title: str) -> list[tuple[list[str], list[tuple[int, int, int]]]]:

        sentences = list()

        # Process by paragraph
        for paragraph in self.dump_db.get_paragraphs(page_title):
            paragraph_links: list[tuple[str, int, int]] = list()
            paragraph_text = paragraph.text