Exemplo n.º 1
0
def load_models(
        biencoder_config='/home/azureuser/test/BLINK/models/biencoder_wiki_large.json',
        biencoder_model='/home/azureuser/test/BLINK/models/biencoder_wiki_large.bin',
        entity_catalogue='/home/azureuser/test/BLINK/models/entity.jsonl',
        entity_encoding='/home/azureuser/test/BLINK/models/all_entities_large.t7',
        logger=None):
    """Load the biencoder and the candidate-entity data.

    Args:
        biencoder_config: Path to the JSON file holding the biencoder params.
        biencoder_model: Path stored under ``"path_to_model"`` in the params.
        entity_catalogue: Forwarded to ``load_candidates`` as first argument.
        entity_encoding: Forwarded to ``load_candidates`` as second argument.
        logger: Optional logger; progress messages are emitted only when it
            is truthy.

    Returns:
        A 4-tuple ``(biencoder, biencoder_params, candidate_encoding,
        faiss_indexer)``.
    """
    def _announce(message):
        # Progress logging is optional: stay silent without a logger.
        if logger:
            logger.info(message)

    _announce("loading biencoder model")
    with open(biencoder_config) as config_fp:
        biencoder_params = json.load(config_fp)
    biencoder_params["path_to_model"] = biencoder_model
    biencoder = load_biencoder(biencoder_params)

    _announce("loading candidate entities")
    candidate_bundle = load_candidates(
        entity_catalogue,
        entity_encoding,
        faiss_index=None,
        index_path=None,
        logger=logger,
    )
    # load_candidates yields six values; only the first and the last are
    # handed back to the caller.
    candidate_encoding, _, _, _, _, faiss_indexer = candidate_bundle

    return biencoder, biencoder_params, candidate_encoding, faiss_indexer
Exemplo n.º 2
0
    def __init__(self, args, logger=None) -> None:
        """Load the NER model, the encoders and the candidate-entity data.

        Args:
            args: Object exposing ``biencoder_config``, ``biencoder_model``,
                ``fast``, ``crossencoder_config``, ``crossencoder_model``,
                ``entity_catalogue``, ``entity_encoding``, ``faiss_index``
                and ``index_path``.
            logger: Optional logger; progress messages are emitted only when
                it is truthy.
        """
        def _announce(message):
            # Progress logging is optional: stay silent without a logger.
            if logger:
                logger.info(message)

        self.args = args
        self.logger = logger
        self.ner_model = NER.get_model()

        _announce("loading biencoder model")
        with open(args.biencoder_config) as config_fp:
            bi_params = json.load(config_fp)
        bi_params["path_to_model"] = args.biencoder_model
        bi_model = load_biencoder(bi_params)

        # The crossencoder is skipped entirely in "fast" mode.
        if args.fast:
            cross_model, cross_params = None, None
        else:
            _announce("loading crossencoder model")
            with open(args.crossencoder_config) as config_fp:
                cross_params = json.load(config_fp)
            cross_params["path_to_model"] = args.crossencoder_model
            cross_model = load_crossencoder(cross_params)

        _announce("loading candidate entities")
        candidate_bundle = _load_candidates(
            args.entity_catalogue,
            args.entity_encoding,
            faiss_index=args.faiss_index,
            index_path=args.index_path,
            logger=logger,
        )
        (encoding, title2id, id2title, id2text,
         wiki2local, indexer) = candidate_bundle

        # Publish everything on the instance only after all loads succeeded.
        self.biencoder = bi_model
        self.biencoder_params = bi_params
        self.crossencoder = cross_model
        self.crossencoder_params = cross_params
        self.candidate_encoding = encoding
        self.title2id = title2id
        self.id2title = id2title
        self.id2text = id2text
        self.wikipedia_id2local_id = wiki2local
        self.faiss_indexer = indexer

        # Map each local id to the Wikipedia URL built from its page id.
        url_template = "https://en.wikipedia.org/wiki?curid=%s"
        self.id2url = {
            local_id: url_template % wiki_id
            for wiki_id, local_id in wiki2local.items()
        }
Exemplo n.º 3
0
def load_models(args, logger=None):
    """Load the encoders and candidate data, and dump the inverse id map.

    Besides loading, this writes ``local_id2wikipedia_id.json`` (the inverse
    of ``wikipedia_id2local_id``) into the current working directory.

    Args:
        args: Object exposing ``biencoder_config``, ``biencoder_model``,
            ``fast``, ``crossencoder_config``, ``crossencoder_model``,
            ``entity_catalogue``, ``entity_encoding``, ``faiss_index`` and
            ``index_path``.
        logger: Optional logger; progress messages are emitted only when it
            is truthy.

    Returns:
        A 10-tuple ``(biencoder, biencoder_params, crossencoder,
        crossencoder_params, candidate_encoding, title2id, id2title, id2text,
        wikipedia_id2local_id, faiss_indexer)``. ``crossencoder`` and
        ``crossencoder_params`` are ``None`` when ``args.fast`` is truthy.
    """
    # load biencoder model
    if logger:
        logger.info("loading biencoder model")
    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(args.biencoder_config, encoding="utf-8") as json_file:
        biencoder_params = json.load(json_file)
        biencoder_params["path_to_model"] = args.biencoder_model
    biencoder = load_biencoder(biencoder_params)

    crossencoder = None
    crossencoder_params = None
    if not args.fast:
        # load crossencoder model
        if logger:
            logger.info("loading crossencoder model")
        with open(args.crossencoder_config, encoding="utf-8") as json_file:
            crossencoder_params = json.load(json_file)
            crossencoder_params["path_to_model"] = args.crossencoder_model
        crossencoder = load_crossencoder(crossencoder_params)

    # load candidate entities
    if logger:
        logger.info("loading candidate entities")
    (
        candidate_encoding,
        title2id,
        id2title,
        id2text,
        wikipedia_id2local_id,
        faiss_indexer,
    ) = _load_candidates(
        args.entity_catalogue,
        args.entity_encoding,
        faiss_index=args.faiss_index,
        index_path=args.index_path,
        logger=logger,
    )

    # Persist the inverse (local id -> wikipedia id) mapping. This used to sit
    # behind an always-true `if True:` guard, so it is simply unconditional.
    local_id2wikipedia_id = {v: k for k, v in wikipedia_id2local_id.items()}
    with open('local_id2wikipedia_id.json', 'w', encoding="utf-8") as json_file:
        json.dump(local_id2wikipedia_id, json_file, indent=4)

    return (
        biencoder,
        biencoder_params,
        crossencoder,
        crossencoder_params,
        candidate_encoding,
        title2id,
        id2title,
        id2text,
        wikipedia_id2local_id,
        faiss_indexer,
    )
Exemplo n.º 4
0
def load_models(args, logger=None):
    """Load the biencoder, the optional crossencoder and the candidate data.

    Args:
        args: Object exposing ``biencoder_config``, ``biencoder_model``,
            ``fast``, ``crossencoder_config``, ``crossencoder_model``,
            ``entity_catalogue`` and ``entity_encoding``. ``faiss_index`` and
            ``index_path`` are optional and default to ``None`` when absent.
        logger: Optional logger; progress messages are emitted only when it
            is truthy.

    Returns:
        A 10-tuple ``(biencoder, biencoder_params, crossencoder,
        crossencoder_params, candidate_encoding, title2id, id2title, id2text,
        wikipedia_id2local_id, faiss_indexer)``. The two crossencoder entries
        are ``None`` when ``args.fast`` is truthy.
    """
    def _log(message):
        # Progress logging is optional: stay silent without a logger.
        if logger:
            logger.info(message)

    _log("loading biencoder model")
    with open(args.biencoder_config) as config_fp:
        biencoder_params = json.load(config_fp)
    biencoder_params["path_to_model"] = args.biencoder_model
    biencoder = load_biencoder(biencoder_params)

    # The crossencoder is skipped entirely in "fast" mode.
    if args.fast:
        crossencoder, crossencoder_params = None, None
    else:
        _log("loading crossencoder model")
        with open(args.crossencoder_config) as config_fp:
            crossencoder_params = json.load(config_fp)
        crossencoder_params["path_to_model"] = args.crossencoder_model
        crossencoder = load_crossencoder(crossencoder_params)

    _log("loading candidate entities")
    candidate_bundle = _load_candidates(
        args.entity_catalogue,
        args.entity_encoding,
        faiss_index=getattr(args, 'faiss_index', None),
        index_path=getattr(args, 'index_path', None),
        logger=logger,
    )
    (candidate_encoding, title2id, id2title, id2text,
     wikipedia_id2local_id, faiss_indexer) = candidate_bundle

    return (
        biencoder, biencoder_params,
        crossencoder, crossencoder_params,
        candidate_encoding, title2id, id2title, id2text,
        wikipedia_id2local_id, faiss_indexer,
    )
Exemplo n.º 5
0
def _build_dataloader(split, params, tokenizer, logger, batch_size, shuffle):
    """Read one dataset split and wrap its tensor data in a DataLoader.

    Args:
        split: Split name passed to ``utils.read_dataset`` (e.g. ``"train"``).
        params: Parameter dict; reads ``data_path``, ``max_context_length``,
            ``max_cand_length``, ``context_key``, ``silent`` and ``debug``.
        tokenizer: Tokenizer forwarded to ``data.process_mention_data``.
        logger: Logger used for the sample-count message.
        batch_size: Batch size of the returned DataLoader.
        shuffle: Use a ``RandomSampler`` when truthy, otherwise a
            ``SequentialSampler``.

    Returns:
        ``(tensor_data, dataloader)``.
    """
    samples = utils.read_dataset(split, params["data_path"])
    logger.info("Read %d %s samples." % (len(samples), split))

    # Only the second value returned by process_mention_data is used here.
    _, tensor_data = data.process_mention_data(
        samples,
        tokenizer,
        params["max_context_length"],
        params["max_cand_length"],
        context_key=params["context_key"],
        silent=params["silent"],
        logger=logger,
        debug=params["debug"],
    )
    if shuffle:
        sampler = RandomSampler(tensor_data)
    else:
        sampler = SequentialSampler(tensor_data)
    dataloader = DataLoader(tensor_data,
                            sampler=sampler,
                            batch_size=batch_size)
    return tensor_data, dataloader


def main(params):
    """Train a bi-encoder ranker, saving and evaluating it after every epoch.

    The model is evaluated once before training, periodically during training
    (every ``eval_interval`` optimizer steps) and after each epoch. The epoch
    with the highest ``normalized_accuracy`` is reloaded at the end and saved
    into ``params["output_path"]``.

    Args:
        params: Parameter dict. It is mutated in place: ``train_batch_size``
            is divided by ``gradient_accumulation_steps`` and
            ``path_to_model`` is overwritten.

    Raises:
        ValueError: If ``gradient_accumulation_steps`` is smaller than 1.
    """
    model_output_path = params["output_path"]
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(model_output_path, exist_ok=True)
    logger = utils.get_logger(params["output_path"])

    # Init model
    reranker = BiEncoderRanker(params)
    tokenizer = reranker.tokenizer
    model = reranker.model

    device = reranker.device
    n_gpu = reranker.n_gpu

    grad_acc_steps = params["gradient_accumulation_steps"]
    if grad_acc_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(grad_acc_steps))

    # An effective batch size of `x`, when accumulating the gradient across
    # `y` batches, is achieved with a per-step batch size of `z = x / y`.
    params["train_batch_size"] = params["train_batch_size"] // grad_acc_steps
    train_batch_size = params["train_batch_size"]
    eval_batch_size = params["eval_batch_size"]

    # Fix the random seeds
    seed = params["seed"]
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(seed)

    # Load train and eval data
    train_tensor_data, train_dataloader = _build_dataloader(
        "train", params, tokenizer, logger,
        batch_size=train_batch_size,
        shuffle=params["shuffle"],
    )
    _, valid_dataloader = _build_dataloader(
        "valid", params, tokenizer, logger,
        batch_size=eval_batch_size,
        shuffle=False,
    )

    # evaluate before training (the result is only reported by evaluate)
    evaluate(
        reranker,
        valid_dataloader,
        params,
        device=device,
        logger=logger,
    )

    time_start = time.time()

    utils.write_to_file(os.path.join(model_output_path, "training_params.txt"),
                        str(params))

    logger.info("Starting training")
    logger.info("device: {} n_gpu: {}, distributed training: {}".format(
        device, n_gpu, False))

    optimizer = get_optimizer(model, params)
    scheduler = get_scheduler(params, optimizer, len(train_tensor_data),
                              logger)

    model.train()

    best_epoch_idx = -1
    best_score = -1

    num_train_epochs = params["num_train_epochs"]
    for epoch_idx in trange(int(num_train_epochs), desc="Epoch"):
        tr_loss = 0

        if params["silent"]:
            iter_ = train_dataloader
        else:
            iter_ = tqdm(train_dataloader, desc="Batch")

        for step, batch in enumerate(iter_):
            batch = tuple(t.to(device) for t in batch)
            context_input, candidate_input, _, _ = batch
            loss, _ = reranker(context_input, candidate_input)

            # Accumulate the unscaled loss: the logging below divides by the
            # number of micro-batches, so summing the already-scaled loss
            # would under-report the average by a factor of grad_acc_steps.
            tr_loss += loss.item()

            if grad_acc_steps > 1:
                loss = loss / grad_acc_steps

            if (step + 1) % (params["print_interval"] * grad_acc_steps) == 0:
                logger.info("Step {} - epoch {} average loss: {}\n".format(
                    step,
                    epoch_idx,
                    tr_loss / (params["print_interval"] * grad_acc_steps),
                ))
                tr_loss = 0

            loss.backward()

            # Only step the optimizer once per accumulation window.
            if (step + 1) % grad_acc_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               params["max_grad_norm"])
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            if (step + 1) % (params["eval_interval"] * grad_acc_steps) == 0:
                logger.info("Evaluation on the development dataset")
                evaluate(
                    reranker,
                    valid_dataloader,
                    params,
                    device=device,
                    logger=logger,
                )
                # Put the model back into training mode after evaluating.
                model.train()
                logger.info("\n")

        logger.info("***** Saving fine - tuned model *****")
        epoch_output_folder_path = os.path.join(model_output_path,
                                                "epoch_{}".format(epoch_idx))
        utils.save_model(model, tokenizer, epoch_output_folder_path)

        results = evaluate(
            reranker,
            valid_dataloader,
            params,
            device=device,
            logger=logger,
        )

        # Keep the previous best on ties: argmax returns the first maximum.
        ls = [best_score, results["normalized_accuracy"]]
        li = [best_epoch_idx, epoch_idx]
        winner = np.argmax(ls)

        best_score = ls[winner]
        best_epoch_idx = li[winner]
        logger.info("\n")

    execution_time = (time.time() - time_start) / 60
    utils.write_to_file(
        os.path.join(model_output_path, "training_time.txt"),
        "The training took {} minutes\n".format(execution_time),
    )
    logger.info("The training took {} minutes\n".format(execution_time))

    # save the best model in the parent_dir
    logger.info("Best performance in epoch: {}".format(best_epoch_idx))
    params["path_to_model"] = os.path.join(
        model_output_path,
        "epoch_{}".format(best_epoch_idx),
        WEIGHTS_NAME,
    )
    reranker = load_biencoder(params)
    utils.save_model(reranker.model, tokenizer, model_output_path)

    if params["evaluate"]:
        params["path_to_model"] = model_output_path
        # Use the same call shape as every other evaluate() call in this
        # function; the previous `evaluate(params, logger=logger)` passed the
        # params dict in place of the ranker and omitted the dataloader.
        # NOTE(review): assumes the reloaded ranker lives on the same device
        # as the one trained above - confirm.
        evaluate(
            reranker,
            valid_dataloader,
            params,
            device=device,
            logger=logger,
        )