Example No. 1
        dest="max_context_length",
        type=int,
        help="Maximum length of context. (Don't set to inherit from training config)",
    )

    # output folder
    parser.add_argument(
        "--output_path",
        dest="output_path",
        type=str,
        default="output",
        help="Path to the output.",
    )

    parser.add_argument(
        "--use_cuda", dest="use_cuda", action="store_true", default=False, help="Run on GPU."
    )
    parser.add_argument(
        "--no_logger", dest="no_logger", action="store_true", default=False, help="Don't log progress."
    )

    args = parser.parse_args()

    logger = None
    if not args.no_logger:
        logger = utils.get_logger(args.output_path)
        logger.setLevel(10)  # 10 == logging.DEBUG

    models = load_models(args, logger)
    run(args, logger, *models)
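
A minimal invocation sketch, assuming the script above is saved as run_eval.py (the file name is not shown in the snippet):

    python run_eval.py --output_path output --use_cuda
    python run_eval.py --max_context_length 128 --no_logger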
Example No. 2
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import argparse
import logging
import numpy
import os
import time
import torch

from elq.index.faiss_indexer import DenseFlatIndexer, DenseIVFFlatIndexer, DenseHNSWFlatIndexer
import elq.candidate_ranking.utils as utils

logger = utils.get_logger()

def main(params): 
    output_path = params["output_path"]

    logger.info("Loading candidate encoding from path: %s" % params["candidate_encoding"])
    candidate_encoding = torch.load(params["candidate_encoding"])
    vector_size = candidate_encoding.size(1)
    index_buffer = params["index_buffer"]
    if params["faiss_index"] == "hnsw":
        logger.info("Using HNSW index in FAISS")
        index = DenseHNSWFlatIndexer(vector_size, index_buffer)
    elif params["faiss_index"] == "ivfflat":
        logger.info("Using IVF Flat index in FAISS")
        index = DenseIVFFlatIndexer(vector_size, 75, 100)  # hard-coded IVF parameters
    else:
        logger.info("Using Flat index in FAISS")
        index = DenseFlatIndexer(vector_size, index_buffer)
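
The example is cut off here. A minimal sketch of how the script typically continues, assuming the index_data/serialize methods exposed by the imported BLINK/ELQ indexers:

    logger.info("Building index.")
    index.index_data(candidate_encoding.numpy())
    logger.info("Done indexing data.")
    index.serialize(output_path)  # write the finished FAISS index to disk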
Example No. 3
biencoder_params["path_to_model"] = args.path_to_model
# entities to use
biencoder_params["entity_dict_path"] = args.entity_dict_path
biencoder_params["degug"] = False
biencoder_params["data_parallel"] = True
biencoder_params["no_cuda"] = False
biencoder_params["max_context_length"] = 32
biencoder_params["encode_batch_size"] = args.batch_size

saved_cand_ids = getattr(args, 'saved_cand_ids', None)
encoding_save_file_dir = args.encoding_save_file_dir
if encoding_save_file_dir is not None and not os.path.exists(
        encoding_save_file_dir):
    os.makedirs(encoding_save_file_dir, exist_ok=True)

logger = utils.get_logger(biencoder_params.get("model_output_path", None))
biencoder = load_biencoder(biencoder_params)
baseline_candidate_encoding = None
if getattr(args, 'compare_saved_embeds', None) is not None:
    baseline_candidate_encoding = torch.load(
        getattr(args, 'compare_saved_embeds'))

candidate_pool = load_candidate_pool(
    biencoder.tokenizer,
    biencoder_params,
    logger,
    getattr(args, 'saved_cand_ids', None),
)
if args.test:
    candidate_pool = candidate_pool[:10]
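
The example stops after trimming the pool. A minimal sketch of the batched encoding step that typically follows, assuming an encode_candidate method on the biencoder (the helper name is an assumption, not shown in the snippet):

candidate_encoding = []
with torch.no_grad():
    for start in range(0, len(candidate_pool), args.batch_size):
        batch = candidate_pool[start : start + args.batch_size]
        candidate_encoding.append(
            biencoder.encode_candidate(batch.to(biencoder.device)).cpu()
        )
candidate_encoding = torch.cat(candidate_encoding)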
Example No. 4
def main(params):
    model_output_path = params["output_path"]
    if not os.path.exists(model_output_path):
        os.makedirs(model_output_path)
    logger = utils.get_logger(params["output_path"])

    # Init model
    reranker = BiEncoderRanker(params)
    tokenizer = reranker.tokenizer
    model = reranker.model

    device = reranker.device
    n_gpu = reranker.n_gpu

    if params["gradient_accumulation_steps"] < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(params["gradient_accumulation_steps"]))

    # With gradient accumulation over `y` batches, an effective batch size of `x` is achieved by using a per-step batch size of `z = x / y`.
    params["train_batch_size"] = (params["train_batch_size"] //
                                  params["gradient_accumulation_steps"])
    train_batch_size = params["train_batch_size"]
    eval_batch_size = params["eval_batch_size"]
    grad_acc_steps = params["gradient_accumulation_steps"]

    # Fix the random seeds
    seed = params["seed"]
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if reranker.n_gpu > 0:
        torch.cuda.manual_seed_all(seed)

    # Load train data
    train_samples = utils.read_dataset("train", params["data_path"])
    logger.info("Read %d train samples." % len(train_samples))
    logger.info("Finished reading all train samples")

    # Load eval data
    try:
        valid_samples = utils.read_dataset("valid", params["data_path"])
    except FileNotFoundError:
        valid_samples = utils.read_dataset("dev", params["data_path"])
    # MUST BE DIVISIBLE BY n_gpus
    if len(valid_samples) > 1024:
        valid_subset = 1024
    else:
        valid_subset = len(
            valid_samples) - len(valid_samples) % torch.cuda.device_count()
    logger.info("Read %d valid samples, choosing %d subset" %
                (len(valid_samples), valid_subset))
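    # e.g. 1000 valid samples on 3 GPUs -> 1000 - (1000 % 3) = 999, so each
    # GPU gets an equal share; any count above 1024 is simply capped at 1024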

    valid_data, valid_tensor_data, extra_ret_values = process_mention_data(
        samples=valid_samples[:valid_subset],  # use subset of valid data
        tokenizer=tokenizer,
        max_context_length=params["max_context_length"],
        max_cand_length=params["max_cand_length"],
        context_key=params["context_key"],
        title_key=params["title_key"],
        silent=params["silent"],
        logger=logger,
        debug=params["debug"],
        add_mention_bounds=(not args.no_mention_bounds),
        candidate_token_ids=None,
        params=params,
    )
    candidate_token_ids = extra_ret_values["candidate_token_ids"]
    valid_tensor_data = TensorDataset(*valid_tensor_data)
    valid_sampler = SequentialSampler(valid_tensor_data)
    valid_dataloader = DataLoader(valid_tensor_data,
                                  sampler=valid_sampler,
                                  batch_size=eval_batch_size)

    # load candidate encodings
    cand_encs = None
    cand_encs_index = None
    if params["freeze_cand_enc"]:
        cand_encs = torch.load(params['cand_enc_path'])
        logger.info("Loaded saved entity encodings")
        if params["debug"]:
            cand_encs = cand_encs[:200]

        # build FAISS index
        cand_encs_index = DenseHNSWFlatIndexer(1)
        cand_encs_index.deserialize_from(params['index_path'])
        logger.info("Loaded FAISS index on entity encodings")
        num_neighbors = 10  # hard negatives retrieved per mention

    # evaluate before training
    results = evaluate(
        reranker,
        valid_dataloader,
        params,
        cand_encs=cand_encs,
        device=device,
        logger=logger,
        faiss_index=cand_encs_index,
    )

    number_of_samples_per_dataset = {}

    time_start = time.time()

    utils.write_to_file(os.path.join(model_output_path, "training_params.txt"),
                        str(params))

    logger.info("Starting training")
    logger.info("device: {} n_gpu: {}, distributed training: {}".format(
        device, n_gpu, False))

    num_train_epochs = params["num_train_epochs"]
    if params["dont_distribute_train_samples"]:
        num_samples_per_batch = len(train_samples)

        train_data, train_tensor_data_tuple, extra_ret_values = process_mention_data(
            samples=train_samples,
            tokenizer=tokenizer,
            max_context_length=params["max_context_length"],
            max_cand_length=params["max_cand_length"],
            context_key=params["context_key"],
            title_key=params["title_key"],
            silent=params["silent"],
            logger=logger,
            debug=params["debug"],
            add_mention_bounds=(not args.no_mention_bounds),
            candidate_token_ids=candidate_token_ids,
            params=params,
        )
        logger.info("Finished preparing training data")
    else:
        num_samples_per_batch = len(train_samples) // num_train_epochs

    trainer_path = params.get("path_to_trainer_state", None)
    optimizer = get_optimizer(model, params)
    scheduler = get_scheduler(params, optimizer, num_samples_per_batch, logger)
    if trainer_path is not None and os.path.exists(trainer_path):
        training_state = torch.load(trainer_path)
        optimizer.load_state_dict(training_state["optimizer"])
        scheduler.load_state_dict(training_state["scheduler"])
        logger.info("Loaded saved training state")

    model.train()

    best_epoch_idx = -1
    best_score = -1
    logger.info("Num samples per batch : %d" % num_samples_per_batch)
    for epoch_idx in trange(params["last_epoch"] + 1,
                            int(num_train_epochs),
                            desc="Epoch"):
        tr_loss = 0
        results = None

        if not params["dont_distribute_train_samples"]:
            start_idx = epoch_idx * num_samples_per_batch
            end_idx = (epoch_idx + 1) * num_samples_per_batch

            train_data, train_tensor_data_tuple, extra_ret_values = process_mention_data(
                samples=train_samples[start_idx:end_idx],
                tokenizer=tokenizer,
                max_context_length=params["max_context_length"],
                max_cand_length=params["max_cand_length"],
                context_key=params["context_key"],
                title_key=params["title_key"],
                silent=params["silent"],
                logger=logger,
                debug=params["debug"],
                add_mention_bounds=(not args.no_mention_bounds),
                candidate_token_ids=candidate_token_ids,
                params=params,
            )
            logger.info(
                "Finished preparing training data for epoch {}: {} samples".
                format(epoch_idx, len(train_tensor_data_tuple[0])))

        batch_train_tensor_data = TensorDataset(*list(train_tensor_data_tuple))
        if params["shuffle"]:
            train_sampler = RandomSampler(batch_train_tensor_data)
        else:
            train_sampler = SequentialSampler(batch_train_tensor_data)

        train_dataloader = DataLoader(batch_train_tensor_data,
                                      sampler=train_sampler,
                                      batch_size=train_batch_size)

        if params["silent"]:
            iter_ = train_dataloader
        else:
            iter_ = tqdm(train_dataloader, desc="Batch")

        for step, batch in enumerate(iter_):
            batch = tuple(t.to(device) for t in batch)
            context_input = batch[0]
            candidate_input = batch[1]
            label_ids = batch[2] if params["freeze_cand_enc"] else None
            mention_idxs = batch[-2]
            mention_idx_mask = batch[-1]
            if params["debug"] and label_ids is not None:
                label_ids[label_ids > 199] = 199

            cand_encs_input = None
            label_input = None
            mention_reps_input = None
            mention_logits = None
            mention_bounds = None
            hard_negs_mask = None
            if params["adversarial_training"]:
                assert cand_encs is not None and label_ids is not None  # due to params["freeze_cand_enc"] being set
                # get the closest N candidates (and appropriate labels)
                # (bs, num_spans, embed_size)
                pos_cand_encs_input = cand_encs[label_ids.to("cpu")]
                pos_cand_encs_input[~mention_idx_mask] = 0

                context_outs = reranker.encode_context(
                    context_input,
                    gold_mention_bounds=mention_idxs,
                    gold_mention_bounds_mask=mention_idx_mask,
                    get_mention_scores=True,
                )
                mention_logits = context_outs['all_mention_logits']
                mention_bounds = context_outs['all_mention_bounds']
                mention_reps = context_outs['mention_reps']
                # mention_reps: (bs, max_num_spans, embed_size) -> masked_mention_reps: (all_pred_mentions_batch, embed_size)
                masked_mention_reps = mention_reps[
                    context_outs['mention_masks']]

                # neg_cand_encs_input_idxs: (all_pred_mentions_batch, num_negatives)
                _, neg_cand_encs_input_idxs = cand_encs_index.search_knn(
                    masked_mention_reps.detach().cpu().numpy(), num_neighbors)
                neg_cand_encs_input_idxs = torch.from_numpy(
                    neg_cand_encs_input_idxs)
                # set "correct" closest entities to -1
                # masked_label_ids: (all_pred_mentions_batch)
                masked_label_ids = label_ids[mention_idx_mask]
                # neg_cand_encs_input_idxs: (all_pred_mentions_batch, num_negatives)
                neg_cand_encs_input_idxs[
                    neg_cand_encs_input_idxs -
                    masked_label_ids.to("cpu").unsqueeze(-1) == 0] = -1

                # reshape back tensor (extract num_spans dimension)
                # (bs, num_spans, num_negatives)
                neg_cand_encs_input_idxs_reconstruct = torch.zeros(
                    label_ids.size(0),
                    label_ids.size(1),
                    neg_cand_encs_input_idxs.size(-1),
                    dtype=neg_cand_encs_input_idxs.dtype)
                neg_cand_encs_input_idxs_reconstruct[
                    mention_idx_mask] = neg_cand_encs_input_idxs
                neg_cand_encs_input_idxs = neg_cand_encs_input_idxs_reconstruct

                # create neg_example_idx (corresponding example (in batch) for each negative)
                # neg_example_idx: (bs * num_negatives)
                neg_example_idx = torch.arange(
                    neg_cand_encs_input_idxs.size(0)).unsqueeze(-1)
                neg_example_idx = neg_example_idx.expand(
                    neg_cand_encs_input_idxs.size(0),
                    neg_cand_encs_input_idxs.size(2))
                neg_example_idx = neg_example_idx.flatten()

                # flatten and filter -1 (i.e. any correct/positive entities)
                # neg_cand_encs_input_idxs: (bs * num_negatives, num_spans)
                neg_cand_encs_input_idxs = neg_cand_encs_input_idxs.permute(
                    0, 2, 1)
                neg_cand_encs_input_idxs = neg_cand_encs_input_idxs.reshape(
                    -1, neg_cand_encs_input_idxs.size(-1))
                # mask invalid negatives (actually the positive example)
                # (bs * num_negatives)
                mask = ~((neg_cand_encs_input_idxs == -1).sum(1).bool()
                         )  # rows without any -1 entry
                # deletes corresponding negative for *all* spans in that example (deletes at most 3 of 10 negatives / example)
                # neg_cand_encs_input_idxs: (bs * num_negatives - invalid_negs, num_spans)
                neg_cand_encs_input_idxs = neg_cand_encs_input_idxs[mask]
                # neg_cand_encs_input_idxs: (bs * num_negatives - invalid_negs)
                neg_example_idx = neg_example_idx[mask]
                # (bs * num_negatives - invalid_negs, num_spans, embed_size)
                neg_cand_encs_input = cand_encs[neg_cand_encs_input_idxs]
                # (bs * num_negatives - invalid_negs, num_spans, embed_size)
                neg_mention_idx_mask = mention_idx_mask[neg_example_idx]
                neg_cand_encs_input[~neg_mention_idx_mask] = 0

                # create input tensors (concat [pos examples, neg examples])
                # (bs + bs * num_negatives, num_spans, embed_size)
                mention_reps_input = torch.cat([
                    mention_reps,
                    mention_reps[neg_example_idx.to(device)],
                ])
                assert mention_reps.size(0) == pos_cand_encs_input.size(0)

                # (bs + bs * num_negatives, num_spans)
                label_input = torch.cat([
                    torch.ones(pos_cand_encs_input.size(0),
                               pos_cand_encs_input.size(1),
                               dtype=label_ids.dtype),
                    torch.zeros(neg_cand_encs_input.size(0),
                                neg_cand_encs_input.size(1),
                                dtype=label_ids.dtype),
                ]).to(device)
                # (bs + bs * num_negatives, num_spans, embed_size)
                cand_encs_input = torch.cat([
                    pos_cand_encs_input,
                    neg_cand_encs_input,
                ]).to(device)
                hard_negs_mask = torch.cat(
                    [mention_idx_mask, neg_mention_idx_mask])

            loss, _, _, _ = reranker(
                context_input,
                candidate_input,
                cand_encs=cand_encs_input,
                text_encs=mention_reps_input,
                mention_logits=mention_logits,
                mention_bounds=mention_bounds,
                label_input=label_input,
                gold_mention_bounds=mention_idxs,
                gold_mention_bounds_mask=mention_idx_mask,
                hard_negs_mask=hard_negs_mask,
                return_loss=True,
            )

            if grad_acc_steps > 1:
                loss = loss / grad_acc_steps

            tr_loss += loss.item()

            if (step + 1) % (params["print_interval"] * grad_acc_steps) == 0:
                logger.info("Step {} - epoch {} average loss: {}\n".format(
                    step,
                    epoch_idx,
                    tr_loss / (params["print_interval"] * grad_acc_steps),
                ))
                tr_loss = 0
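                # tr_loss is zeroed every print interval, so the value logged
                # above is the average over the most recent window only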

            loss.backward()

            if (step + 1) % grad_acc_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               params["max_grad_norm"])
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            if (step + 1) % (params["eval_interval"] * grad_acc_steps) == 0:
                logger.info("Evaluation on the development dataset")
                loss = None  # for GPU mem management
                mention_reps = None
                mention_reps_input = None
                label_input = None
                cand_encs_input = None

                evaluate(
                    reranker,
                    valid_dataloader,
                    params,
                    cand_encs=cand_encs,
                    device=device,
                    logger=logger,
                    faiss_index=cand_encs_index,
                    get_losses=params["get_losses"],
                )
                model.train()
                logger.info("\n")

        logger.info("***** Saving fine - tuned model *****")
        epoch_output_folder_path = os.path.join(model_output_path,
                                                "epoch_{}".format(epoch_idx))
        utils.save_model(model, tokenizer, epoch_output_folder_path)
        torch.save(
            {
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
            }, os.path.join(epoch_output_folder_path, "training_state.th"))

        output_eval_file = os.path.join(epoch_output_folder_path,
                                        "eval_results.txt")
        logger.info("Valid data evaluation")
        valid_results = evaluate(
            reranker,
            valid_dataloader,
            params,
            cand_encs=cand_encs,
            device=device,
            logger=logger,
            faiss_index=cand_encs_index,
            get_losses=params["get_losses"],
        )
        logger.info("Train data evaluation")
        results = evaluate(
            reranker,
            train_dataloader,
            params,
            cand_encs=cand_encs,
            device=device,
            logger=logger,
            faiss_index=cand_encs_index,
            get_losses=params["get_losses"],
        )

        # select the best epoch by validation F1 (the train-data evaluation
        # above is logged for reference only)
        ls = [best_score, valid_results["normalized_f1"]]
        li = [best_epoch_idx, epoch_idx]

        best_score = ls[np.argmax(ls)]
        best_epoch_idx = li[np.argmax(ls)]
        logger.info("\n")

    execution_time = (time.time() - time_start) / 60
    utils.write_to_file(
        os.path.join(model_output_path, "training_time.txt"),
        "The training took {} minutes\n".format(execution_time),
    )
    logger.info("The training took {} minutes\n".format(execution_time))

    # save the best model in the parent_dir
    logger.info("Best performance in epoch: {}".format(best_epoch_idx))
    params["path_to_model"] = os.path.join(model_output_path,
                                           "epoch_{}".format(best_epoch_idx))
    utils.save_model(reranker.model, tokenizer, model_output_path)

    if params["evaluate"]:
        params["path_to_model"] = model_output_path
        evaluate(params,
                 cand_encs=cand_encs,
                 logger=logger,
                 faiss_index=cand_encs_index)
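
The snippet ends here. In the original repository the script is driven by a module-level parser along these lines (a sketch; ElqParser and add_training_args are assumptions), which also explains why args.no_mention_bounds is reachable inside main: args is a module-level global.

if __name__ == "__main__":
    parser = ElqParser(add_model_args=True)  # assumed parser class
    parser.add_training_args()
    args = parser.parse_args()
    main(args.__dict__)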