Example #1
def main():
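    """Train a Biencoder over every training world, optionally mining hard negatives
    before each world's pass, run scheduled dev evaluations, and finish with a full
    dev/test evaluation."""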
    print("===experiment starts===")
    exp_start_time = time.time()
    P = Params()
    opts = P.opts
    experiment_logdir = experiment_logger(args=opts)
    print("experiment_logdir:", experiment_logdir)
    P.dump_params(experiment_dir=experiment_logdir)
    cuda_devices = cuda_device_parser(str_ids=opts.cuda_devices)
    TRAIN_WORLDS, DEV_WORLDS, TEST_WORLDS = worlds_loader(args=opts)

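    # Empty vocabulary and a bucket iterator that batches mentions sorted by context length.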
    vocab = Vocabulary()
    iterator_for_training_and_evaluating_mentions = BucketIterator(
        batch_size=opts.batch_size_for_train,
        sorting_keys=[('context', 'num_tokens')])
    iterator_for_training_and_evaluating_mentions.index_with(vocab)

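    # Shared word embedder plus the BERT tokenizer and token indexer used by the readers.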
    embloader = EmbLoader(args=opts)
    emb_mapper, emb_dim, textfieldEmbedder = embloader.emb_returner()
    tokenIndexing = TokenIndexerReturner(args=opts)
    global_tokenizer = tokenIndexing.berttokenizer_returner()
    global_tokenIndexer = tokenIndexing.token_indexer_returner()

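    # Either restore the encoders and biencoder from a checkpoint or build them from scratch.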
    if opts.load_from_checkpoint:
        mention_encoder, entity_encoder, model = load_model_objects(
            model_path=opts.model_path,
            mention_encoder_filename=opts.mention_encoder_filename,
            entity_encoder_filename=opts.entity_encoder_filename,
            model_filename=opts.model_filename)
    else:
        mention_encoder = Pooler_for_mention(args=opts,
                                             word_embedder=textfieldEmbedder)
        entity_encoder = Pooler_for_title_and_desc(
            args=opts, word_embedder=textfieldEmbedder)
        model = Biencoder(args=opts,
                          mention_encoder=mention_encoder,
                          entity_encoder=entity_encoder,
                          vocab=vocab)
    model = model.cuda()

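    # Adam over trainable parameters only; dev evaluation runs every epoch when hard
    # negatives are mined, otherwise on a sparser schedule.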
    optimizer = optim.Adam(filter(lambda param: param.requires_grad,
                                  model.parameters()),
                           lr=opts.lr,
                           eps=opts.epsilon,
                           weight_decay=opts.weight_decay,
                           betas=(opts.beta1, opts.beta2),
                           amsgrad=opts.amsgrad)
    devEvalEpochs = list(range(1, 1000)) if opts.add_hard_negatives else \
                    [1, 3, 5] + [k * 10 for k in range(1, 100)]

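    # One training pass over every training world per epoch.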
    for epoch in range(opts.num_epochs):
        oneep_train_start = time.time()
        for world_name in TRAIN_WORLDS:
            reader = WorldsReader(args=opts,
                                  world_name=world_name,
                                  token_indexers=global_tokenIndexer,
                                  tokenizer=global_tokenizer)

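            # Mine hard negatives for this world with the current encoders frozen.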
            if opts.add_hard_negatives:
                with torch.no_grad():
                    mention_encoder.eval()
                    entity_encoder.eval()
                    hardNegativeSearcher = HardNegativesSearcherForEachEpochStart(
                        args=opts,
                        world_name=world_name,
                        reader=reader,
                        embedder=textfieldEmbedder,
                        mention_encoder=mention_encoder,
                        entity_encoder=entity_encoder,
                        vocab=vocab,
                        berttokenizer=global_tokenizer,
                        bertindexer=global_tokenIndexer)
                    hardNegativeSearcher.hardNegativesSearcherandSetter()

            trains = reader.read('train')
            mention_encoder.train()
            entity_encoder.train()
            trainer = Trainer(
                model=model,
                optimizer=optimizer,
                iterator=iterator_for_training_and_evaluating_mentions,
                train_dataset=trains,
                cuda_device=cuda_devices,
                num_epochs=1)
            trainer.train()

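        # Intermediate dev evaluation over all dev worlds at scheduled epochs.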
        if epoch + 1 in devEvalEpochs:
            print('\n===================\n', 'TEMP DEV EVALUATION@ Epoch',
                  epoch + 1, '\n===================\n')
            t_entire_h1c, t_entire_h10c, t_entire_h50c, t_entire_h64c, t_entire_h100c, t_entire_h500c, t_entire_datapoints \
                = oneLineLoaderForDevOrTestEvaluation(
                dev_or_test_flag='dev',
                opts=opts,
                global_tokenIndexer=global_tokenIndexer,
                global_tokenizer=global_tokenizer,
                textfieldEmbedder=textfieldEmbedder,
                mention_encoder=mention_encoder,
                entity_encoder=entity_encoder,
                vocab=vocab,
                experiment_logdir=experiment_logdir,
                finalEvalFlag=0,
                trainEpoch=epoch+1)

            result = oneLineLoaderForDevOrTestEvaluationRawData(
                dev_or_test_flag='dev',
                opts=opts,
                global_tokenIndexer=global_tokenIndexer,
                global_tokenizer=global_tokenizer,
                textfieldEmbedder=textfieldEmbedder,
                mention_encoder=mention_encoder,
                entity_encoder=entity_encoder,
                vocab=vocab,
                experiment_logdir=experiment_logdir,
                finalEvalFlag=0,
                trainEpoch=epoch + 1)

            devEvalExperimentEntireDevWorldLog(experiment_logdir,
                                               t_entire_h1c,
                                               t_entire_h10c,
                                               t_entire_h50c,
                                               t_entire_h64c,
                                               t_entire_h100c,
                                               t_entire_h500c,
                                               t_entire_datapoints,
                                               epoch=epoch)

            devEvalExperimentEntireDevWorldLogRawData(experiment_logdir,
                                                      result,
                                                      epoch=epoch)

        oneep_train_end = time.time()
        print('epoch {0} train time'.format(epoch + 1),
              oneep_train_end - oneep_train_start, 'sec')

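        # Optionally checkpoint the model and both encoders after this epoch.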
        if opts.save_checkpoints:
            save_model_objects(
                model_object=model,
                mention_encoder_object=mention_encoder,
                entity_encoder_object=entity_encoder,
                model_path=experiment_logdir,
                mention_encoder_filename=opts.mention_encoder_filename,
                entity_encoder_filename=opts.entity_encoder_filename,
                model_filename=opts.model_filename,
                epoch=epoch)

    print('===training finished===')

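    # Final dev and test evaluation with gradients disabled.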
    with torch.no_grad():
        model.eval()
        print('===FINAL Evaluation starts===')

        for dev_or_test_flag in ['dev', 'test']:
            print('\n===================\n', dev_or_test_flag, 'EVALUATION',
                  '\n===================\n')
            entire_h1c, entire_h10c, entire_h50c, entire_h64c, entire_h100c, entire_h500c, entire_datapoints \
                = oneLineLoaderForDevOrTestEvaluation(dev_or_test_flag=dev_or_test_flag,
                                                      opts=opts,
                                                      global_tokenIndexer=global_tokenIndexer,
                                                      global_tokenizer=global_tokenizer,
                                                      textfieldEmbedder=textfieldEmbedder,
                                                      mention_encoder=mention_encoder,
                                                      entity_encoder=entity_encoder,
                                                      vocab=vocab,
                                                      experiment_logdir=experiment_logdir,
                                                      finalEvalFlag=1,
                                                      trainEpoch=-1)

            result \
                = oneLineLoaderForDevOrTestEvaluationRawData(dev_or_test_flag=dev_or_test_flag,
                                                      opts=opts,
                                                      global_tokenIndexer=global_tokenIndexer,
                                                      global_tokenizer=global_tokenizer,
                                                      textfieldEmbedder=textfieldEmbedder,
                                                      mention_encoder=mention_encoder,
                                                      entity_encoder=entity_encoder,
                                                      vocab=vocab,
                                                      experiment_logdir=experiment_logdir,
                                                      finalEvalFlag=1,
                                                      trainEpoch=-1)

            dev_or_test_finallog(
                entire_h1c,
                entire_h10c,
                entire_h50c,
                entire_h64c,
                entire_h100c,
                entire_h500c,
                entire_datapoints,
                dev_or_test_flag,
                experiment_logdir,
            )

            dev_or_test_finallog_rawdata(result, experiment_logdir,
                                         dev_or_test_flag)

    exp_end_time = time.time()
    print('===experiment finished', exp_end_time - exp_start_time, 'sec')
    print(experiment_logdir)
Example #2
def main():
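    """Train an in-batch biencoder, encode all KB entities with the trained entity
    encoder, index the embeddings with FAISS, and evaluate retrieval on dev and test."""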
    exp_start_time = time.time()
    Parameters = Biencoder_params()
    opts = Parameters.get_params()
    experiment_logdir = experiment_logger(args=opts)
    Parameters.dump_params(experiment_dir=experiment_logdir)
    cuda_devices = cuda_device_parser(str_ids=opts.cuda_devices)

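    # Reader that tokenizes the training mentions; canonical names and definitions are
    # joined with CANONICAL_AND_DEF_CONNECTTOKEN.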
    reader_for_mentions = FixedDatasetTokenizedReader(
        args=opts,
        canonical_and_def_connecttoken=CANONICAL_AND_DEF_CONNECTTOKEN)
    trains = reader_for_mentions.read('train')
    if not opts.allen_lazyload:
        print('\ntrain statistics:', len(trains), '\n')

    vocab = Vocabulary()
    iterator_for_training_and_evaluating_mentions = BucketIterator(
        batch_size=opts.batch_size_for_train,
        sorting_keys=[('context', 'num_tokens')])
    iterator_for_training_and_evaluating_mentions.index_with(vocab)

    embloader = EmbLoader(args=opts)
    emb_mapper, emb_dim, textfieldEmbedder = embloader.emb_returner()

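    # Pick the mention encoder that matches the configured training model.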
    if opts.model_for_training == 'blink_implementation_inbatchencoder':
        mention_encoder = Pooler_for_blink_mention(
            args=opts, word_embedder=textfieldEmbedder)
    else:
        mention_encoder = Concatenate_Right_and_Left_MentionEncoder(
            args=opts, input_dim=emb_dim, word_embedder=textfieldEmbedder)

    current_cui2idx, current_idx2cui, current_cui2emb, current_cui2cano, current_cui2def = \
        reader_for_mentions.currently_stored_KB_dataset_returner()

    if opts.model_for_training == 'biencoder':
        entity_encoder = Pooler_for_cano_and_def(
            args=opts, word_embedder=textfieldEmbedder)
        model = InBatchBiencoder(args=opts,
                                 mention_encoder=mention_encoder,
                                 entity_encoder=entity_encoder,
                                 vocab=vocab,
                                 input_dim=emb_dim)
    elif opts.model_for_training == 'blink_implementation_inbatchencoder':
        entity_encoder = Pooler_for_cano_and_def(
            args=opts, word_embedder=textfieldEmbedder)
        model = InBatchBLINKBiencoder(args=opts,
                                      mention_encoder=mention_encoder,
                                      entity_encoder=entity_encoder,
                                      vocab=vocab)
    else:
        print('currently', opts.model_for_training, 'is not supported')
        raise NotImplementedError
    model = model.cuda()

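    # Train the biencoder unless only the entity encoder is being debugged.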
    if not opts.debug_for_entity_encoder:
        optimizer = optim.Adam(filter(lambda param: param.requires_grad,
                                      model.parameters()),
                               lr=opts.lr,
                               eps=opts.epsilon,
                               weight_decay=opts.weight_decay,
                               betas=(opts.beta1, opts.beta2),
                               amsgrad=opts.amsgrad)
        trainer = Trainer(
            model=model,
            optimizer=optimizer,
            iterator=iterator_for_training_and_evaluating_mentions,
            train_dataset=trains,
            # validation_dataset=devs,
            cuda_device=cuda_devices,
            num_epochs=opts.num_epochs)
        trainer.train()
    else:
        print('\n==Skip Biencoder training==\n')

    # Save the model
    serialization_dir = 'model'
    config_file = os.path.join(serialization_dir, 'config.json')
    vocabulary_dir = os.path.join(serialization_dir, 'vocabulary')
    weights_file = os.path.join(serialization_dir, 'weights.th')
    model_pytorch_file = os.path.join(serialization_dir, 'model.th')
    os.makedirs(serialization_dir, exist_ok=True)
    #params.to_file(config_file)
    #vocab.save_to_files(vocabulary_dir)
    torch.save(model, model_pytorch_file)

    with torch.no_grad():
        model.eval()
        model.switch2eval()

        print(
            '======Start encoding all entities in KB=====\n======1. Start Tokenizing All Entities====='
        )
        entity_encoder_wrapping_model = WrappedModel_for_entityencoding(
            args=opts, entity_encoder=entity_encoder, vocab=vocab)
        entity_encoder_wrapping_model.eval()
        # entity_encoder_wrapping_model.cpu()

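        # Load every KB entity (canonical name + definition) and encode it in batches.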
        Tokenizer = reader_for_mentions.berttokenizer_returner()
        TokenIndexer = reader_for_mentions.token_indexer_returner()
        kbentity_loader = AllEntityCanonical_and_Defs_loader(
            args=opts,
            idx2cui=current_idx2cui,
            cui2cano=current_cui2cano,
            cui2def=current_cui2def,
            textfield_embedder=textfieldEmbedder,
            pretrained_tokenizer=Tokenizer,
            tokenindexer=TokenIndexer,
            canonical_and_def_connect_token=CANONICAL_AND_DEF_CONNECTTOKEN)
        Allentity_embedding_encodeIterator = InKBAllEntitiesEncoder(
            args=opts,
            entity_loader_datasetreaderclass=kbentity_loader,
            entity_encoder_wrapping_model=entity_encoder_wrapping_model,
            vocab=vocab)
        print('======2. Encoding All Entities=====')
        cuidx2encoded_emb = Allentity_embedding_encodeIterator.encoding_all_entities()
        if opts.debug_for_entity_encoder:
            cuidx2encoded_emb = parse_cuidx2encoded_emb_for_debugging(
                cuidx2encoded_emb=cuidx2encoded_emb,
                original_cui2idx=current_cui2idx)
        cui2encoded_emb = parse_cuidx2encoded_emb_2_cui2emb(
            cuidx2encoded_emb=cuidx2encoded_emb,
            original_cui2idx=current_cui2idx)
        print('=====Encoding all entities in KB FINISHED!=====')

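        # Put the encoded entity embeddings into a FAISS index for nearest-neighbour retrieval.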
        print('\n+++++Indexing KB from encoded entities+++++')
        forstoring_encoded_entities_to_faiss = ForOnlyFaiss_KBIndexer(
            args=opts,
            input_cui2idx=current_cui2idx,
            input_idx2cui=current_idx2cui,
            input_cui2emb=cui2encoded_emb,
            search_method_for_faiss=opts.search_method_before_re_sorting_for_faiss,
            entity_emb_dim=768)
        print('+++++Indexing KB from encoded entities FINISHED!+++++')

        print('Loading BLINKBiencoder')
        blinkbiencoder_onlyfor_encodingmentions = BLINKBiencoder_OnlyforEncodingMentions(
            args=opts, mention_encoder=mention_encoder, vocab=vocab)
        blinkbiencoder_onlyfor_encodingmentions.cuda()
        blinkbiencoder_onlyfor_encodingmentions.eval()
        print('Loaded: BLINKBiencoder')

        print('Evaluation for BLINK start')
        blinkBiEncoderEvaluator = BLINKBiEncoderTopXRetriever(
            args=opts,
            vocab=vocab,
            blinkbiencoder_onlyfor_encodingmentions=blinkbiencoder_onlyfor_encodingmentions,
            fortrainigmodel_faiss_stored_kb_kgemb=forstoring_encoded_entities_to_faiss.indexed_faiss_returner(),
            reader_for_mentions=reader_for_mentions)
        finalblinkEvaluator = DevandTest_BLINKBiEncoder_IterateEvaluator(
            args=opts,
            blinkBiEncoderEvaluator=blinkBiEncoderEvaluator,
            experiment_logdir=experiment_logdir)
        finalblinkEvaluator.final_evaluation(dev_or_test_flag='dev')
        finalblinkEvaluator.final_evaluation(dev_or_test_flag='test')

        exp_end_time = time.time()
        print('Experiment time', math.floor(exp_end_time - exp_start_time),
              'sec')
Example #3
def main():
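    """Load a trained biencoder (from a checkpoint when configured) and return an
    evaluator prepared for the 'wikipedia' test world."""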
    print("Loading biencoder.")
    opts = load_opts_from_json(parameters_json)
    cuda_devices = cuda_device_parser(str_ids=opts.cuda_devices)
    TRAIN_WORLDS, DEV_WORLDS, TEST_WORLDS = worlds_loader(args=opts)

    vocab = Vocabulary()
    iterator_for_training_and_evaluating_mentions = BucketIterator(
        batch_size=opts.batch_size_for_train,
        sorting_keys=[('context', 'num_tokens')])
    iterator_for_training_and_evaluating_mentions.index_with(vocab)

    embloader = EmbLoader(args=opts)
    emb_mapper, emb_dim, textfieldEmbedder = embloader.emb_returner()
    tokenIndexing = TokenIndexerReturner(args=opts)
    global_tokenizer = tokenIndexing.berttokenizer_returner()
    global_tokenIndexer = tokenIndexing.token_indexer_returner()

    if opts.load_from_checkpoint:
        mention_encoder, entity_encoder, model = load_model_objects(
            model_path=opts.model_path,
            mention_encoder_filename=opts.mention_encoder_filename,
            entity_encoder_filename=opts.entity_encoder_filename,
            model_filename=opts.model_filename)
        mention_encoder.share_memory()
        entity_encoder.share_memory()
        model.share_memory()
    else:
        mention_encoder = Pooler_for_mention(args=opts,
                                             word_embedder=textfieldEmbedder)
        entity_encoder = Pooler_for_title_and_desc(
            args=opts, word_embedder=textfieldEmbedder)
        model = Biencoder(args=opts,
                          mention_encoder=mention_encoder,
                          entity_encoder=entity_encoder,
                          vocab=vocab)
    model = model.cuda()

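    # No training here: build an online reader and an evaluator for a single world.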
    with torch.no_grad():
        finalEvalFlag = 0
        world_name = 'wikipedia'
        dev_or_test_flag = 'test'

        reader_for_eval = WorldsReaderOnline(
            args=opts,
            world_name=world_name,
            token_indexers=global_tokenIndexer,
            tokenizer=global_tokenizer)
        Evaluator = Evaluate_one_world_raw_data(
            args=opts,
            world_name=world_name,
            reader=reader_for_eval,
            embedder=textfieldEmbedder,
            trainfinished_mention_encoder=mention_encoder,
            trainfinished_entity_encoder=entity_encoder,
            vocab=vocab,
            experiment_logdir=None,
            dev_or_test=dev_or_test_flag,
            berttokenizer=global_tokenizer,
            bertindexer=global_tokenIndexer)
        Evaluate_one_world_raw_data.finalEvalFlag = copy.copy(finalEvalFlag)

        return Evaluator