Ejemplo n.º 1
0
    def __init__(
        self,
        base_model: str,
        max_seq_length: int,
        max_entities: int,
        max_entity_span: int,
        device: torch.device,
        token_map: Optional[np.ndarray],
    ):
        self.device = device
        self.max_seq_length = max_seq_length
        self.max_entities = max_entities
        self.max_entity_span = max_entity_span
        self.token_map = token_map

        self.label_to_idx = {
            label: i
            for i, label in enumerate(self.all_labels)
        }

        self.tokenizer = AutoTokenizer.from_pretrained(base_model)
        self.sep_id, self.cls_id, self.pad_id, self.mask_id, self.unk_id = get_special_ids(
            self.tokenizer)
        if self.token_map is not None:
            self.sep_id, self.cls_id, self.pad_id, self.mask_id, self.unk_id = self.token_map[
                [
                    self.sep_id, self.cls_id, self.pad_id, self.mask_id,
                    self.unk_id
                ]]

        # To be set by load method
        self.data: dict[Split, Sequences] = dict()
        self.loaded = False

        self.data_limit = None  # To be used for debugging
Ejemplo n.º 2
0
def example_from_str(
        text:             str,
        entity_spans:     list[tuple[int, int]],
        daluke:           AutoDaLUKE,
    ) -> BatchedExamples:
    subword_ids = get_subword_ids(text, daluke.tokenizer)
    sep, cls_, pad, _, _ = get_special_ids(daluke.tokenizer)
    flat_subword_ids = list(chain(*subword_ids))
    # Reduce the word ids to lower vocab if we use monoliguification of multilingual model
    if daluke.token_map is not None:
        flat_subword_ids = daluke.token_map[flat_subword_ids]

    w = Words.build(
        ids     = torch.IntTensor(flat_subword_ids),
        max_len = daluke.metadata["max-seq-length"],
        sep_id  = sep,
        cls_id  = cls_,
        pad_id  = pad,
    )
    e = Entities.build(
        ids             = get_entity_id_tensor(text, entity_spans, daluke.entity_vocab),
        spans           = get_entity_subword_spans(subword_ids, entity_spans),
        max_entities    = daluke.metadata["max-entities"],
        max_entity_span = daluke.metadata["max-entity-span"],
    )
    return BatchedExamples.build(
        [
            Example(
                words    = w,
                entities = e,
            ),
        ],
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    )
Ejemplo n.º 3
0
    def _reduce_tokens(self) -> tuple[np.ndarray, int]:
        token_counts = np.zeros(self.tokenizer.vocab_size, dtype=np.int32)

        log("Counting tokens in dataset")
        for example in tqdm(self.examples):
            word_ids = np.array(example["word_ids"])
            word_ids, counts = unique(word_ids, return_counts=True)
            token_counts[word_ids] += counts

        log("%i of %i tokens in the vocab are used" %
            ((token_counts > 0).sum(), self.tokenizer.vocab_size))
        *ids, unk_id = get_special_ids(self.tokenizer)
        unk_count = token_counts[unk_id]
        token_counts[[*ids, unk_id]] = -1
        sort_idx = np.argsort(token_counts)[::-1]
        keep_idx = sort_idx[:self.vocab_size]
        keep = np.zeros_like(token_counts, dtype=bool)
        keep[keep_idx] = True
        keep[[*ids, unk_id]] = True  # Always keep special tokens
        token_map = np.arange(self.tokenizer.vocab_size)
        token_map[~keep] = unk_id
        for i, j in enumerate(np.where(keep)[0]):
            token_map[j] = i
        log(
            "Reduced token vocabulary to %i tokens" % keep.sum(),
            "%.6f %% of word tokens in the dataset are now %s" % (
                100 * (unk_count + 1 + token_counts[~keep].sum()) /
                (unk_count + 1 + token_counts.sum()),
                self.tokenizer.unk_token,
            ),
        )
        np.save(self.token_map_file, token_map)
        log("Saved token map to '%s'" % self.token_map_file)

        return token_map, int(keep.sum())
Ejemplo n.º 4
0
    def __init__(
        self,
        data_dir: str,
        metadata: dict,
        entity_vocab: dict,
        device: torch.device,
        word_mask_prob: float,
        word_unmask_prob: float,
        word_randword_prob: float,
        ent_mask_prob: float,
        only_load_validation=False,
        vocab_size: int | None = None,
        token_map: np.ndarray | None = None,
        ent_min_mention: int = None,
    ):
        """ Loads a generated json dataset prepared by the preprocessing pipeline """
        self.data_dir = data_dir
        self.metadata = metadata
        self.ent_ids = {info["id"] for info in entity_vocab.values()}
        self.ent_min_mention = ent_min_mention
        self.device = device

        self.max_sentence_len = metadata["max-seq-length"]
        self.max_entities = metadata["max-entities"]
        self.max_entity_span = metadata["max-entity-span"]

        self.word_mask_prob = word_mask_prob
        self.word_unmask_prob = word_unmask_prob
        self.word_randword_prob = word_randword_prob
        self.ent_mask_prob = ent_mask_prob
        self.only_load_validation = only_load_validation

        self.tokenizer = AutoTokenizer.from_pretrained(metadata["base-model"])
        self.sep_id, self.cls_id, self.pad_id, self.word_mask_id, __ = get_special_ids(
            self.tokenizer)
        if token_map is not None:
            self.sep_id, self.cls_id, self.pad_id, self.word_mask_id = token_map[
                [self.sep_id, self.cls_id, self.pad_id, self.word_mask_id]]
        self.ent_mask_id = entity_vocab[ENTITY_MASK_TOKEN]["id"]
        # Don't insert ids that are special tokens when performing random word insertion in the masking
        # The allowed range is dependant on the placement of special ids
        vocab_size = vocab_size or self.tokenizer.vocab_size
        self.random_word_id_range = (self.word_mask_id + 1, vocab_size)\
            if self.word_mask_id < vocab_size-1 else\
                (self.tokenizer.convert_tokens_to_ids(self.tokenizer.unk_token)+1, vocab_size-1)

        with TT.profile("Build data"):
            self.train_examples, self.val_examples = self.build_examples()
Ejemplo n.º 5
0
def train(
    rank:           int,
    world_size:     int,
    *,
    resume:         bool,
    location:       str,
    name:           str,
    quiet:          bool,
    save_every:     int,
    validate_every: int,
    post_command:   str,
    explicit_args:  set[str],
    params:         Hyperparams,
):
    # Get filepath within path context
    fpath = lambda path: os.path.join(location, path) if isinstance(path, str) else os.path.join(location, *path)

    # Setup multi-gpu if used
    setup(rank, world_size)

    is_master = rank < 1  # Are we on the main node?
    is_distributed = rank != -1  # Are we performing distributed computing?
    num_workers = torch.distributed.get_world_size() if is_distributed else 1

    # Update locations
    TrainResults.subfolder = name
    Hyperparams.subfolder = name

    # Setup logger
    log.configure(
        os.path.join(location, name, "pretraining-worker=%s.log" % (rank if is_distributed else 0)),
        "DaLUKE pretraining on node %i" % rank,
        log_commit  = True,
        print_level = (Levels.INFO if quiet else Levels.DEBUG) if is_master else None,
        append      = resume,  # Append to existing log file if we are resuming training
    )

    post_time, post_command = parse_post_command(post_command)
    execute_post_command = False
    if post_time:
        log("Quitting in %.2f h and running command '%s'" % ((post_time-time.time())/3600, post_command))

    if resume:
        log("Resuming from %s" % name)
        # Load results and hyperparameters from earlier training
        res = TrainResults.load(location)
        # Close unended profiles
        close_tt(res.tt)
        TT.fuse(res.tt)
        res.tt = TT
        tmp_saved_pu = res.parameter_update
        loaded_params = Hyperparams.load(location)
        # Overwrite ff-size if given explicitly
        if "ff_size" in explicit_args:
            loaded_params.ff_size = params.ff_size
        params = loaded_params
    else:
        tmp_saved_pu = None
    log.section("Starting pretraining with the following hyperparameters", params)
    log("Training using %i workers" % num_workers)

    log("Reading metadata and entity vocabulary")
    with open(fpath(DatasetBuilder.metadata_file)) as f:
        metadata = json.load(f)
    with open(fpath(DatasetBuilder.entity_vocab_file)) as f:
        entity_vocab = json.load(f)
    log("Loaded metadata:", json.dumps(metadata, indent=4))
    log(f"Loaded entity vocabulary of {len(entity_vocab)} entities")
    if params.ent_min_mention:
        log("Removing entities with less than %i mentions" % params.ent_min_mention)
        entity_vocab = { ent: info for ent, info in entity_vocab.items()
            if info["count"] >= params.ent_min_mention or ent in {"[PAD]", "[UNK]", "[MASK]"} }
        log("After filtering, entity vocab now has %i entities" % len(entity_vocab))

    # Device should be cuda:rank or just cuda if single gpu, else cpu
    if is_distributed:
        device = torch.device("cuda", index=rank)
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    log.debug(
        "Hardware for this worker:",
        "CPU: %s" % cpuinfo.get_cpu_info()["brand_raw"],
        "GPU: %s" % (torch.cuda.get_device_name(device) if torch.cuda.is_available() else "NA"),
        sep="\t\n",
    )

    if params.entity_loss_weight:
        log("Setting up loss function with entity loss weighting")
        # Don't weigh special tokens
        weights = torch.Tensor([0, 0, 0, *(1 / info["count"] for info in entity_vocab.values() if info["count"])]).to(device)
        entity_criterion = nn.CrossEntropyLoss(weight=weights)
    else:
        log("Setting up loss function without entity loss weighting")
        entity_criterion = nn.CrossEntropyLoss()
    word_criterion = nn.CrossEntropyLoss()
    loss_calculator = lambda w, e: params.word_ent_weight * w + (1 - params.word_ent_weight) * e

    # Load dataset and training results
    bert_config = AutoConfig.from_pretrained(metadata["base-model"])
    if metadata["reduced-vocab"]:
        token_map_file = fpath(DatasetBuilder.token_map_file)
        log("Loading token map from '%s'" % token_map_file)
        token_map = np.load(token_map_file)
        tokenizer = AutoTokenizer.from_pretrained(metadata["base-model"])
        *__, unk_id = get_special_ids(tokenizer)
        token_reduction = token_map_to_token_reduction(token_map, unk_id)
    else:
        token_map = None

    log("Building dataset")
    data = DataLoader(
        location,
        metadata,
        entity_vocab,
        device,
        params.word_mask_prob,
        params.word_unmask_prob,
        params.word_randword_prob,
        params.ent_mask_prob,
        vocab_size=metadata["vocab-size"],
        token_map=token_map,
        ent_min_mention=params.ent_min_mention,
    )
    sampler = (DistributedSampler if is_distributed else RandomSampler)(data.train_examples)
    log("Built %i examples" % len(data))

    loader = data.get_dataloader(params.ff_size, sampler)
    val_loader = data.get_dataloader(params.ff_size, SequentialSampler(data.val_examples), validation=True)

    # Number of subbatches in each parameter update (batch)
    grad_accumulation_steps = params.batch_size // (params.ff_size * num_workers)
    # How many full batches can be made from the dataset
    batches_in_data = len(data) // params.batch_size
    log(
        "Parameter updates:               %i" % params.parameter_updates,
        "Subbatches per parameter update: %i" % grad_accumulation_steps,
        "Subbatches generated:            %i" % len(loader),
        "Batches needed to cover dataset: %i" % batches_in_data,
    )

    if not resume:
        # Calculate parameter differences, when at least 20k examples have been seen
        paramdiff_every = ceil(MIN_EXAMPLES_PER_PARAMDIFF / params.batch_size)
        log("Recalculating parameter differences every %i'th parameter update" % paramdiff_every)
        top_k = [1, 3, 10]
        log("Calculating top %s accuracies" % top_k)
        if validate_every:
            val_updates = unique(np.array(
                np.arange(-1, params.parameter_updates, validate_every).tolist() + [params.parameter_updates-1]
            ))[1:]
        else:
            val_updates = np.array([], dtype=int)
        res = TrainResults(
            runtime           = np.zeros(params.parameter_updates),
            lr                = np.zeros(params.parameter_updates),
            parameter_update  = 0,

            losses            = np.zeros(params.parameter_updates),
            scaled_loss       = np.zeros(params.parameter_updates),

            top_k             = top_k,
            w_losses          = np.zeros(params.parameter_updates),
            e_losses          = np.zeros(params.parameter_updates),
            w_accuracies      = np.zeros((params.parameter_updates, len(top_k))),
            e_accuracies      = np.zeros((params.parameter_updates, len(top_k))),

            val_param_updates = val_updates,
            val_losses        = np.zeros(len(val_updates)),
            val_w_losses      = np.zeros(len(val_updates)),
            val_e_losses      = np.zeros(len(val_updates)),
            val_w_accuracies  = np.zeros((len(val_updates), len(top_k))),
            val_e_accuracies  = np.zeros((len(val_updates), len(top_k))),

            paramdiff_every   = paramdiff_every,
            groups_to_slices  = None,  # Set later
            orig_params       = None,
            paramdiff_1       = None,

            luke_exclusive_params = None,  # Set later
            att_mats_from_base    = None,  # Set later

            tt = TT,
        )

    save_pus = set(range(-1, params.parameter_updates, save_every)).union({params.parameter_updates-1})
    log("Saving model at parameter updates: %s" % sorted(save_pus),
        "Validating at parameter updates: %s" % res.val_param_updates.tolist())

    # Build model, possibly by loading previous weights
    log.section("Setting up model")
    bert_config = AutoConfig.from_pretrained(metadata["base-model"])
    if params.ent_hidden_size is None:
        params.ent_hidden_size = bert_config.hidden_size
    else:
        assert params.ent_hidden_size <= bert_config.hidden_size,\
            "Entity hidden size (%i) cannot be larger than hidden size in '%s' (%i)" % (
                params.hidden_size,
                metadata["base-model"],
                bert_config.hidden_size,
            )

    log("Initializing model")
    model_cls = BertAttentionPretrainTaskDaLUKE if params.bert_attention else PretrainTaskDaLUKE
    model = model_cls(
        bert_config,
        ent_vocab_size        = len(entity_vocab),
        ent_embed_size        = params.ent_embed_size,
        ent_hidden_size       = params.ent_hidden_size,
        ent_intermediate_size = params.ent_intermediate_size,
    ).to(device)
    bert_config.vocab_size = metadata["vocab-size"]
    log("Bert config", bert_config.to_json_string())

    if params.lukeinit:
        log("Initializing weights in accordance with LUKE")
        model.apply(lambda module: model.init_weights(module, bert_config.initializer_range))
    # Load parameters from base model
    if not params.no_base_model:
        log("Loading base model parameters")
        with TT.profile("Loading base model parameters"):
            base_model = AutoModelForPreTraining.from_pretrained(metadata["base-model"])
            new_weights = load_base_model_weights(
                model,
                base_model.state_dict(),
                params.bert_attention,
            )
            if metadata["reduced-vocab"]:
                log("Removing unneeded token weights")
                reduced_model = model_cls(
                    bert_config,
                    ent_vocab_size        = len(entity_vocab),
                    ent_embed_size        = params.ent_embed_size,
                    ent_hidden_size       = params.ent_hidden_size,
                    ent_intermediate_size = params.ent_intermediate_size,
                ).to(device)
                copy_with_reduced_state_dict(token_reduction, model, reduced_model)
                model = reduced_model
    else:
        new_weights = set(model.state_dict())
    # Initialize self-attention query matrices to BERT word query matrices
    att_mat_keys = set()
    if not params.bert_attention and not params.no_base_model:
        log("Initializing new attention matrices with%s PCA" % ("" if params.pcainit else "out"))
        att_mat_keys = model.init_special_attention(params.pcainit, device)
    if not resume:
        res.luke_exclusive_params = new_weights
        res.att_mats_from_base = att_mat_keys
        if is_master:
            res.orig_params = all_params(model).cpu().numpy()
    log("Pretraining model initialized with %s parameters" % thousand_seps(len(model)))

    # Unfixes params at this parameter update
    unfix_base_model_params_pu = round(params.bert_fix_prop * params.parameter_updates)
    log("Unfixing base model params after %i parameter updates" % unfix_base_model_params_pu)

    if resume:
        mpath = fpath((TrainResults.subfolder, MODEL_OUT.format(i=res.parameter_update)))
        log("Loading model from '%s'" % mpath)
        model.load_state_dict(torch.load(mpath, map_location=device))
        log(f"Resuming training saved at parameter update {res.parameter_update}")
    else:
        res.groups_to_slices, t = all_params_groups_to_slices(model, bert_config.num_hidden_layers)
        log("Parameter groups and positions", t)
        res.paramdiff_1 = { name: np.zeros(ceil(params.parameter_updates/res.paramdiff_every)) for name in res.groups_to_slices }
    if is_distributed:
        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
    non_ddp_model = model.module if is_distributed else model

    log("Setting up optimizer, scaler, and learning rate scheduler")
    optimizer = get_optimizer(non_ddp_model, params.weight_decay, params.lr)
    scaler = amp.GradScaler() if params.fp16 else None
    scheduler = get_lr_scheduler(
        optimizer,
        int(params.warmup_prop * params.parameter_updates),
        params.parameter_updates,
        unfix_base_model_params_pu,
    )
    if resume:
        optimizer.load_state_dict(torch.load(fpath((TrainResults.subfolder, OPTIMIZER_OUT.format(i=res.parameter_update))), map_location=device))
        scheduler.load_state_dict(torch.load(fpath((TrainResults.subfolder, SCHEDULER_OUT.format(i=res.parameter_update))), map_location=device))
        if params.fp16:
            scaler.load_state_dict(torch.load(fpath((TrainResults.subfolder, SCALER_OUT.format(i=res.parameter_update))), map_location=device))
        res.parameter_update += 1  # We saved the data at pu i, but should now commence pu i+1

    log.debug("Time distribution before starting training", TT)
    log_memory_stats(device)

    log.section(f"Training DaLUKE for {params.parameter_updates} parameter updates")
    model.zero_grad()  # To avoid tracking of model parameter manipulation
    model.train()

    # Start with transfer learned weights locked
    fix_base_model_params(res.luke_exclusive_params, non_ddp_model, True)
    fixed_params = True

    # Save initial parameters
    if is_master and not resume:
        with TT.profile("Saving progress"):
            paths = save_training(location, params, model.module if is_distributed else model,
                res, optimizer, scheduler, scaler, -1)
            log.debug("Saved initial state to", *paths)

    batch_iter = iter(loader)
    for i in range(res.parameter_update, params.parameter_updates):
        TT.profile("Parameter update")
        res.parameter_update = i
        if i >= unfix_base_model_params_pu and fixed_params:
            log("Unfixing base model params")
            fix_base_model_params(res.luke_exclusive_params, model, False)
            fixed_params = False
        if is_distributed and i % batches_in_data == 0:
            sampler.set_epoch(i // batches_in_data)

        # Losses and accuracies for this parameter update
        t_loss, w_loss, e_loss, s_loss = 0, 0, 0, 0
        w_accuracies = np.zeros((grad_accumulation_steps, len(res.top_k)))
        e_accuracies = np.zeros((grad_accumulation_steps, len(res.top_k)))

        # Loop over enough batches to make a parameter update
        for j in range(grad_accumulation_steps):
            TT.profile("Sub-batch")
            try:
                batch = next(batch_iter)
            except StopIteration:
                batch_iter = iter(loader)
                batch = next(batch_iter)

            TT.profile("FP and gradients")
            with amp.autocast() if params.fp16 else contextlib.ExitStack():
                word_preds, ent_preds = model(batch)
                # Compute and backpropagate loss
                word_loss = word_criterion(word_preds, batch.word_mask_labels)
                ent_loss = entity_criterion(ent_preds, batch.ent_mask_labels)
                has_entities = not torch.isnan(ent_loss).item()
                ent_loss = torch.nan_to_num(ent_loss)
            loss = loss_calculator(word_loss, ent_loss)
            loss /= grad_accumulation_steps

            # Only sync parameters on grad updates, aka last pass of this loop
            with model.no_sync() if is_distributed and j < grad_accumulation_steps - 1 else contextlib.ExitStack():
                if params.fp16:
                    scaled_loss = scaler.scale(loss)
                    scaled_loss.backward()
                    s_loss += scaled_loss.item()
                else:
                    loss.backward()

            t_loss += loss.item()
            w_loss += word_loss.item() / grad_accumulation_steps
            e_loss += ent_loss.item() / grad_accumulation_steps if has_entities else 0

            if torch.cuda.is_available():
                torch.cuda.synchronize(rank if is_distributed else None)

            TT.end_profile()

            # Save accuracy for statistics
            if is_master:
                with TT.profile("Training accuracy"):
                    w_accuracies[j] = top_k_accuracy(batch.word_mask_labels, word_preds, res.top_k)
                    e_accuracies[j] = top_k_accuracy(batch.ent_mask_labels, ent_preds, res.top_k)

            TT.end_profile()

        # Update model parameters
        with TT.profile("Parameter step"):
            if params.fp16:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()
            scheduler.step()
            model.zero_grad()

        # Calculate how much gradient has changed
        if is_master and i % res.paramdiff_every == 0:
            with torch.no_grad(), TT.profile("Parameter changes"):
                log.debug("Calculating parameter changes")
                orig_pars = torch.from_numpy(res.orig_params).to(device)
                current_pars = all_params(model.module if is_distributed else model)
                absdiff = torch.abs(current_pars-orig_pars)
                for blockname, slice_ in res.groups_to_slices.items():
                    j = i // res.paramdiff_every
                    res.paramdiff_1[blockname][j] = absdiff[slice_].sum().item()
                del orig_pars, current_pars

        res.losses[i]       = t_loss
        res.w_losses[i]     = w_loss
        res.e_losses[i]     = e_loss
        res.scaled_loss[i]  = s_loss
        res.lr[i]           = scheduler.get_last_lr()[0]
        res.w_accuracies[i] = np.mean(w_accuracies, axis=0)
        res.e_accuracies[i] = np.nanmean(e_accuracies, axis=0)
        res.runtime[i]      = TT.end_profile()
        log.debug(
            "Performed parameter update %i / %i in %.2f s" % (i, params.parameter_updates-1, res.runtime[i]),
            f"  Loss (total, word, entity, scaled): {t_loss:9.4f}, {w_loss:9.4f}, {e_loss:9.4f}, {s_loss:.4f}",
            f"  Accuracy (word, entity): {100*res.w_accuracies[i, 0]:7.2f} %, {100*res.e_accuracies[i, 0]:7.2f} %",
        )

        if i in res.val_param_updates and is_master:
            TT.profile("Model validation")
            log("Validating model")
            vi = res.val_param_updates.tolist().index(i)
            res.val_w_losses[vi], res.val_e_losses[vi], res.val_w_accuracies[vi], res.val_e_accuracies[vi] =\
                validate_model(model, val_loader, word_criterion, entity_criterion, res.top_k)
            res.val_losses[vi] = loss_calculator(res.val_w_losses[vi], res.val_e_losses[vi])
            log(
                "Validation loss:",
                "  Total:  %9.4f" % res.val_losses[vi],
                "  Word:   %9.4f" % res.val_w_losses[vi],
                "  Entity: %9.4f" % res.val_e_losses[vi],
                "Validation accuracy:",
                "  Word:   %7.2f %%" % (100 * res.val_w_accuracies[vi, 0]),
                "  Entity: %7.2f %%" % (100 * res.val_e_accuracies[vi, 0]),
            )
            model.train()
            TT.end_profile()
            log.debug("Time distribution so far", TT)

        # Save results and model
        if is_master and i in save_pus:
            with TT.profile("Saving progress"):
                save_progress(location, i, tmp_saved_pu, save_pus, params,
                    model.module if is_distributed else model, res, optimizer, scheduler, scaler)
        if i in save_pus:
            log_memory_stats(device)

        # If timed out, save, quit, and run resume command
        if post_time and time.time() > post_time:
            log_memory_stats(device)
            log.section("Time limit reached. Quitting and running command '%s'" % post_command)
            with TT.profile("Saving progress"):
                save_progress(location, i, tmp_saved_pu, save_pus, params,
                    model.module if is_distributed else model, res, optimizer, scheduler, scaler)
            execute_post_command = True
            break

    log.debug("Time distribution", TT)

    # Clean up multi-gpu if used
    cleanup(rank)

    if is_master and execute_post_command:
        os.system(post_command)