Example #1
 def setUp(self):
     entity_dump_dir = "test/data/entity_loader/entity_data/entity_mappings"
     self.entity_symbols = EntitySymbols.load_from_cache(
         entity_dump_dir, alias_cand_map_file="alias2qids.json"
     )
     self.config = {
         "data_config": {
             "train_in_candidates": False,
             "entity_dir": "test/data/entity_loader/entity_data",
             "entity_prep_dir": "prep",
             "alias_cand_map": "alias2qids.json",
             "max_aliases": 3,
             "data_dir": "test/data/entity_loader",
             "overwrite_preprocessed_data": True,
         },
         "run_config": {"distributed": False},
     }
Example #2
def main():
    args = parse_args()
    logging.info(json.dumps(vars(args), indent=4))
    entity_symbols = EntitySymbols.load_from_cache(
        load_dir=os.path.join(args.data_dir, args.entity_symbols_dir))
    train_file = os.path.join(args.data_dir, args.train_file)
    save_dir = os.path.join(args.save_dir, "stats")
    logging.info(f"Will save data to {save_dir}")
    utils.ensure_dir(save_dir)
    # compute_histograms(save_dir, entity_symbols)
    compute_occurrences(
        save_dir,
        train_file,
        entity_symbols,
        args.lower,
        args.strip,
        num_workers=args.num_workers,
    )
Example #3
def main():
    args = parse_args()
    print(ujson.dumps(vars(args), indent=4))
    entity_symbols = EntitySymbols.load_from_cache(
        os.path.join(args.entity_dir, args.entity_map_dir),
        alias_cand_map_file=args.alias_cand_map,
        alias_idx_file=args.alias_idx_map,
    )
    print("DO LOWERCASE IS", "uncased" in args.bert_model)
    tokenizer = BertTokenizer.from_pretrained(
        args.bert_model,
        do_lower_case="uncased" in args.bert_model,
        cache_dir=args.word_model_cache,
    )
    model = BertModel.from_pretrained(
        args.bert_model,
        cache_dir=args.word_model_cache,
        output_attentions=False,
        output_hidden_states=False,
    )
    if not args.cpu:
        model = model.to("cuda")
    model.eval()

    entity2avgtitle = build_title_table(
        args.cpu, args.batch_size, model, tokenizer, entity_symbols
    )
    save_fold = os.path.dirname(args.save_file)
    if len(save_fold) > 0:
        if not os.path.exists(save_fold):
            os.makedirs(save_fold)
    if args.output_method == "pt":
        save_obj = (entity_symbols.get_qid2eid(), entity2avgtitle)
        torch.save(obj=save_obj, f=args.save_file)
    else:
        res = {}
        for qid in tqdm(entity_symbols.get_all_qids(), desc="Building final json"):
            eid = entity_symbols.get_eid(qid)
            res[qid] = entity2avgtitle[eid].tolist()
        with open(args.save_file, "w") as out_f:
            ujson.dump(res, out_f)
    print(f"Done!")
Example #4
def main():
    args = parse_args()
    print(ujson.dumps(vars(args), indent=4))
    num_processes = min(args.processes, int(0.8 * multiprocessing.cpu_count()))
    print("Loading entity symbols")
    entity_symbols = EntitySymbols.load_from_cache(
        os.path.join(args.entity_dir, args.entity_map_dir),
        alias_cand_map_file=args.alias_cand_map,
        alias_idx_file=args.alias_idx_map,
    )

    in_file = os.path.join(args.data_dir, args.train_file)
    print(f"Getting slice counts from {in_file}")
    qid_cnts = get_counts(num_processes, in_file)
    with open(os.path.join(args.data_dir, "qid_cnts_train.json"), "w") as out_f:
        ujson.dump(qid_cnts, out_f)
    df = build_reg_csv(qid_cnts, entity_symbols)

    df.to_csv(args.out_file, index=False)
    print(f"Saved file to {args.out_file}")
Example #5
 def setUp(self):
     """ENTITY SYMBOLS
      {
        "multi word alias2":[["Q2",5.0],["Q1",3.0],["Q4",2.0]],
        "alias1":[["Q1",10.0],["Q4",6.0]],
        "alias3":[["Q1",30.0]],
        "alias4":[["Q4",20.0],["Q3",15.0],["Q2",1.0]]
      }
      EMBEDDINGS
      {
          "key": "learned",
          "freeze": false,
          "load_class": "LearnedEntityEmb",
          "args":
          {
            "learned_embedding_size": 10,
          }
      },
      {
         "key": "learned_type",
         "load_class": "LearnedTypeEmb",
         "freeze": false,
         "args": {
             "type_labels": "type_pred_mapping.json",
             "max_types": 1,
             "type_dim": 5,
             "merge_func": "addattn",
             "attn_hidden_size": 5
         }
     }
     """
     self.args = parser_utils.parse_boot_and_emm_args(
         "test/run_args/test_model_training.json")
     self.entity_symbols = EntitySymbols.load_from_cache(
         os.path.join(self.args.data_config.entity_dir,
                      self.args.data_config.entity_map_dir),
         alias_cand_map_file=self.args.data_config.alias_cand_map,
     )
     emmental.init(log_dir="test/temp_log")
     if not os.path.exists(emmental.Meta.log_path):
         os.makedirs(emmental.Meta.log_path)
Example #6
 def setUp(self):
     # tests that the sampling is done correctly on indices
     # load data from directory
     self.args = parser_utils.parse_boot_and_emm_args(
         "test/run_args/test_type_data.json")
     self.tokenizer = BertTokenizer.from_pretrained(
         "bert-base-cased",
         do_lower_case=False,
         cache_dir="test/data/emb_data/pretrained_bert_models",
     )
     self.is_bert = True
     self.entity_symbols = EntitySymbols.load_from_cache(
         os.path.join(self.args.data_config.entity_dir,
                      self.args.data_config.entity_map_dir),
         alias_cand_map_file=self.args.data_config.alias_cand_map,
     )
     self.temp_file_name = "test/data/data_loader/test_data.jsonl"
     self.guid_dtype = lambda max_aliases: np.dtype([
         ("sent_idx", "i8", 1),
         ("subsent_idx", "i8", 1),
         ("alias_orig_list_pos", "i8", max_aliases),
     ])
Example #7
    def test_filter_qids(self):
        entity_dump_dir = "test/data/entity_loader/entity_data/entity_mappings"
        entity_db = EntitySymbols.load_from_cache(
            load_dir=entity_dump_dir, alias_cand_map_file="alias2qids.json"
        )
        qid2count = {"Q1": 10, "Q2": 20, "Q3": 2, "Q4": 4}
        perc_emb_drop = 0.8
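        # With perc_emb_drop = 0.8 and four entities, only the most frequent
        # entity (Q2, count 20) keeps its own compressed eid; the rest collapse
        # onto a shared tail ("toes") eid. That is why Q1, Q3, and Q4 all map
        # to eid 2 below and only two top-k entity rows remain.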

        gold_qid2topk_eid = {"Q1": 2, "Q2": 1, "Q3": 2, "Q4": 2}
        gold_old2new_eid = {0: 0, -1: -1, 2: 1, 3: 2}
        gold_new_toes_eid = 2
        gold_num_topk_entities = 2

        (
            qid2topk_eid,
            old2new_eid,
            new_toes_eid,
            num_topk_entities,
        ) = filter_qids(perc_emb_drop, entity_db, qid2count)
        self.assertEqual(gold_qid2topk_eid, qid2topk_eid)
        self.assertEqual(gold_old2new_eid, old2new_eid)
        self.assertEqual(gold_new_toes_eid, new_toes_eid)
        self.assertEqual(gold_num_topk_entities, num_topk_entities)
Example #8
    def test_create_entities(self):
        truealias2qids = {
            "alias1": [["Q1", 10.0], ["Q4", 6]],
            "multi word alias2": [["Q2", 5.0], ["Q1", 3], ["Q4", 2]],
            "alias3": [["Q1", 30.0]],
            "alias4": [["Q4", 20], ["Q3", 15.0], ["Q2", 1]],
        }

        trueqid2title = {
            "Q1": "alias1",
            "Q2": "multi alias2",
            "Q3": "word alias3",
            "Q4": "nonalias4",
        }

        # the non-candidate class is included in entity_dump
        trueqid2eid = {"Q1": 1, "Q2": 2, "Q3": 3, "Q4": 4}
        truealias2id = {"alias1": 0, "alias3": 1, "alias4": 2, "multi word alias2": 3}
        truealiastrie = {"multi word alias2": 0, "alias1": 1, "alias3": 2, "alias4": 3}

        entity_symbols = EntitySymbols(
            max_candidates=3,
            alias2qids=truealias2qids,
            qid2title=trueqid2title,
        )
        tri_as_dict = {}
        for k in entity_symbols._alias_trie:
            tri_as_dict[k] = entity_symbols._alias_trie[k]

        self.assertEqual(entity_symbols.max_candidates, 3)
        self.assertEqual(entity_symbols.max_eid, 4)
        self.assertEqual(entity_symbols.max_alid, 3)
        self.assertDictEqual(entity_symbols._alias2qids, truealias2qids)
        self.assertDictEqual(entity_symbols._qid2title, trueqid2title)
        self.assertDictEqual(entity_symbols._qid2eid, trueqid2eid)
        self.assertDictEqual(tri_as_dict, truealiastrie)
        self.assertDictEqual(entity_symbols._alias2id, truealias2id)
        self.assertIsNone(entity_symbols._qid2aliases)

        # Test load from dump
        temp_save_dir = "test/data/entity_loader_test"
        entity_symbols.save(temp_save_dir)
        entity_symbols = EntitySymbols.load_from_cache(temp_save_dir)

        self.assertEqual(entity_symbols.max_candidates, 3)
        self.assertEqual(entity_symbols.max_eid, 4)
        self.assertEqual(entity_symbols.max_alid, 3)
        self.assertDictEqual(entity_symbols._alias2qids, truealias2qids)
        self.assertDictEqual(entity_symbols._qid2title, trueqid2title)
        self.assertDictEqual(entity_symbols._qid2eid, trueqid2eid)
        self.assertDictEqual(tri_as_dict, truealiastrie)
        self.assertDictEqual(entity_symbols._alias2id, truealias2id)
        self.assertIsNone(entity_symbols._qid2aliases)
        shutil.rmtree(temp_save_dir)

        # Test edit mode
        entity_symbols = EntitySymbols(
            max_candidates=3,
            alias2qids=truealias2qids,
            qid2title=trueqid2title,
            edit_mode=True,
        )
        trueqid2aliases = {
            "Q1": {"alias1", "multi word alias2", "alias3"},
            "Q2": {"multi word alias2", "alias4"},
            "Q3": {"alias4"},
            "Q4": {"alias1", "multi word alias2", "alias4"},
        }

        self.assertDictEqual(entity_symbols._qid2aliases, trueqid2aliases)
Example #9
def create_task(args, entity_symbols=None, slice_datasets=None):
    """Returns an EmmentalTask for named entity disambiguation (NED).

    Args:
        args: args
        entity_symbols: entity symbols (default None)
        slice_datasets: slice datasets used in scorer (default None)

    Returns: EmmentalTask for NED
    """

    if entity_symbols is None:
        entity_symbols = EntitySymbols.load_from_cache(
            load_dir=os.path.join(args.data_config.entity_dir,
                                  args.data_config.entity_map_dir),
            alias_cand_map_file=args.data_config.alias_cand_map,
            alias_idx_file=args.data_config.alias_idx_map,
        )

    # Create sentence encoder
    bert_model = BertEncoder(args.data_config.word_embedding,
                             output_size=args.model_config.hidden_size)

    # Gets the tasks that query for the individual embeddings (e.g., word, entity, type, kg)
    # The device dict will store which embedding modules we want on the cpu
    (
        embedding_task_flows,  # task flows for standard embeddings (e.g., kg, type, entity)
        embedding_module_pool,  # module for standard embeddings
        embedding_module_device_dict,  # module device dict for standard embeddings
        # some embeddings output indices for BERT so we handle these embeddings in our BERT layer
        # (see comments in get_through_bert_embedding_tasks)
        extra_bert_embedding_layers,
        embedding_payload_inputs,  # the layers that are fed into the payload
        embedding_total_sizes,  # total size of all embeddings
    ) = get_embedding_tasks(args, entity_symbols)

    # Add the extra embedding layers to BERT module
    for emb_obj in extra_bert_embedding_layers:
        bert_model.add_embedding(emb_obj)

    # Create the embedding payload, attention network, and prediction layer modules
    if args.model_config.attn_class == "BootlegM2E":
        embedding_payload = EmbeddingPayload(args, entity_symbols,
                                             embedding_total_sizes)
        attn_network = BootlegM2E(args, entity_symbols)
        pred_layer = PredictionLayer(args)

    elif args.model_config.attn_class == "Bootleg":
        embedding_payload = EmbeddingPayload(args, entity_symbols,
                                             embedding_total_sizes)
        attn_network = Bootleg(args, entity_symbols)
        pred_layer = PredictionLayer(args)

    elif args.model_config.attn_class == "BERTNED":
        # Baseline model
        embedding_payload = EmbeddingPayloadBase(args, entity_symbols,
                                                 embedding_total_sizes)
        attn_network = BERTNED(args, entity_symbols)
        pred_layer = NoopPredictionLayer(args)

    else:
        raise ValueError(f"{args.model_config.attn_class} is not supported.")

    sliced_scorer = BootlegSlicedScorer(args.data_config.train_in_candidates,
                                        slice_datasets)

    # Create module pool and combine with embedding module pool
    module_pool = nn.ModuleDict({
        BERT_MODEL_NAME: bert_model,
        "embedding_payload": embedding_payload,
        "attn_network": attn_network,
        PRED_LAYER: pred_layer,
    })
    module_pool.update(embedding_module_pool)

    # Create task flow
    task_flow = [
        {
            "name": BERT_MODEL_NAME,
            "module": BERT_MODEL_NAME,
            "inputs": [
                ("_input_", "entity_cand_eid"),
                ("_input_", "token_ids"),
            ],  # We pass the entity_cand_eids to BERT in case of embeddings that require word information
        },
        *embedding_task_flows,  # Add task flows to create embedding inputs
        {
            "name":
            "embedding_payload",
            "module":
            "embedding_payload",  # outputs: embedding_tensor
            "inputs": [
                ("_input_", "start_span_idx"),
                ("_input_", "end_span_idx"),
                *embedding_payload_inputs,  # all embeddings
            ],
        },
        {
            "name":
            "attn_network",
            "module":
            "attn_network",  # output: predictions from layers, output entity embeddings
            "inputs": [
                (BERT_MODEL_NAME, 0),  # sentence embedding
                (BERT_MODEL_NAME, 1),  # sentence embedding mask
                ("embedding_payload", 0),
                ("_input_", "entity_cand_eid_mask"),
                ("_input_", "start_span_idx"),
                ("_input_", "end_span_idx"),
                (
                    "_input_",
                    "batch_on_the_fly_kg_adj",
                ),  # special kg adjacency embedding prepped in dataloader
            ],
        },
        {
            "name":
            PRED_LAYER,
            "module":
            PRED_LAYER,
            "inputs": [
                (
                    "attn_network",
                    "intermed_scores",
                ),  # output predictions from intermediate layers from the model
                (
                    "attn_network",
                    "ent_embs",
                ),  # output entity embeddings (from all KG modules)
                (
                    "attn_network",
                    "final_scores",
                ),  # score (empty except for baseline model)
            ],
        },
    ]

    return EmmentalTask(
        name=NED_TASK,
        module_pool=module_pool,
        task_flow=task_flow,
        loss_func=disambig_loss,
        output_func=disambig_output,
        require_prob_for_eval=False,
        require_pred_for_eval=True,
        # action_outputs are used to stitch together sentence fragments
        action_outputs=[
            ("_input_", "sent_idx"),
            ("_input_", "subsent_idx"),
            ("_input_", "alias_orig_list_pos"),
            ("_input_", "for_dump_gold_cand_K_idx_train"),
            (PRED_LAYER, "ent_embs"),  # entity embeddings
        ],
        scorer=Scorer(customize_metric_funcs={
            f"{NED_TASK}_scorer": sliced_scorer.bootleg_score
        }),
        module_device=embedding_module_device_dict,
    )
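A minimal usage sketch for wiring this task into a model, assuming a config parsed with parser_utils.parse_boot_and_emm_args (the path below is a placeholder) and mirroring how run_model further down adds the task:

args = parser_utils.parse_boot_and_emm_args("test/run_args/test_model_training.json")
model = EmmentalModel(name="Bootleg")
# entity_symbols and slice_datasets default to None; the task loads entity symbols itself
model.add_task(create_task(args))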
Example #10
def init_process(entity_dump_f):
    global ed_global
    ed_global = EntitySymbols.load_from_cache(load_dir=entity_dump_f)
Example #11
def compress_topk_embeddings(args):
    assert 0 < args.perc_emb_drop < 1, f"perc_emb_drop must be between 0 and 1"
    print(
        f"Loading entity symbols from {os.path.join(args.entity_dir, 'entity_mappings')}"
    )
    entity_db = EntitySymbols.load_from_cache(
        os.path.join(args.entity_dir, "entity_mappings"))
    print(f"Loading qid2count from {args.qid2count}")
    qid2count = utils.load_json_file(args.qid2count)
    print(f"Filtering qids")
    (
        qid2topk_eid,
        old2new_eid,
        toes_eid,
        new_num_topk_entities,
    ) = filter_qids(args.perc_emb_drop, entity_db, qid2count)
    if len(args.model_path) > 0:
        assert (args.save_model_path is not None
                and len(args.save_model_path) > 0
                ), f"If you give a model path, you must give a save checkpoint"
        print(f"Filtering embeddings")
        state_dict, model_state_dict = load_statedict(args.model_path)
        try:
            get_nested_item(model_state_dict, ENTITY_EMB_KEYS)
        except Exception:
            print(
                f"ERROR: not all of {ENTITY_EMB_KEYS} are in model_state_dict")
            raise
        model_state_dict = filter_embs(
            new_num_topk_entities,
            entity_db,
            old2new_eid,
            qid2topk_eid,
            toes_eid,
            model_state_dict,
        )
        # Generate the new old2new_eid weight vector to save in model_state_dict
        oldeid2topkeid = torch.arange(
            0, entity_db.num_entities_with_pad_and_nocand)
        # The +2 is to account for pads and unks. The -1 is as there are issues with -1 in the indexing
        # for entity embeddings. So we must manually make it the last entry
        oldeid2topkeid[-1] = new_num_topk_entities + 2 - 1
        for qid, new_eid in tqdm(qid2topk_eid.items(), desc="Setting new ids"):
            old_eid = entity_db.get_eid(qid)
            oldeid2topkeid[old_eid] = new_eid
        assert oldeid2topkeid[0] == 0, f"The 0 eid shouldn't be changed"
        assert (oldeid2topkeid[-1] == new_num_topk_entities + 2 -
                1), "The -1 eid should still map to the last row"
        model_state_dict = set_nested_item(model_state_dict, ENTITY_EID_KEYS,
                                           oldeid2topkeid)
        # Remove the eid2reg value as that was with the old entity id mapping
        try:
            model_state_dict = set_nested_item(model_state_dict,
                                               ENTITY_REG_KEYS, None)
        except Exception:
            print(
                f"Could not remove regularization. If your model was trained with regularization mapping on "
                f"the learned entity embedding, this should not happen.")
        print(model_state_dict["module_pool"]["learned"].keys())
        state_dict["model"] = model_state_dict
        print(f"Saving model at {args.save_model_path}")
        torch.save(state_dict, args.save_model_path)
    print(
        f"Saving topk to eid at {os.path.join(args.entity_dir, 'entity_mappings', args.save_qid2topk_file)}"
    )
    utils.dump_json_file(
        os.path.join(args.entity_dir, "entity_mappings",
                     args.save_qid2topk_file),
        qid2topk_eid,
    )

    if args.model_config is not None:
        modify_config(
            args.model_config,
            args.save_model_config,
            args.save_model_path,
            os.path.join(args.entity_dir, "entity_mappings",
                         args.save_qid2topk_file),
            args.perc_emb_drop,
        )
Example #12
def run_model(mode, config, run_config_path=None):
    """
    Main run method for Emmental Bootleg models.
    Args:
        mode: run mode (train, eval, dump_preds, dump_embs)
        config: parsed model config
        run_config_path: original config path (for saving)

    Returns: scores for train/eval modes, or (final_result_file, final_out_emb_file) for dump modes
    """

    # Set up distributed backend and save configuration files
    setup(config, run_config_path)

    # Load entity symbols
    log_rank_0_info(logger, f"Loading entity symbols...")
    entity_symbols = EntitySymbols.load_from_cache(
        load_dir=os.path.join(config.data_config.entity_dir,
                              config.data_config.entity_map_dir),
        alias_cand_map_file=config.data_config.alias_cand_map,
        alias_idx_file=config.data_config.alias_idx_map,
    )
    # Create tasks
    tasks = [NED_TASK]
    if config.data_config.type_prediction.use_type_pred is True:
        tasks.append(TYPE_PRED_TASK)

    # Create splits for data loaders
    data_splits = [TRAIN_SPLIT, DEV_SPLIT, TEST_SPLIT]
    # Slices are for eval so we only split on test/dev
    slice_splits = [DEV_SPLIT, TEST_SPLIT]
    # If doing eval, only run on test data
    if mode in ["eval", "dump_preds", "dump_embs"]:
        data_splits = [TEST_SPLIT]
        slice_splits = [TEST_SPLIT]
        # We only do dumping if weak labels is True
        if mode in ["dump_preds", "dump_embs"]:
            if config.data_config[
                    f"{TEST_SPLIT}_dataset"].use_weak_label is False:
                raise ValueError(
                    f"When calling dump_preds or dump_embs, we require use_weak_label to be True."
                )

    # Gets embeddings that need to be prepped during data prep or in the __get_item__ method
    batch_on_the_fly_kg_adj = get_dataloader_embeddings(config, entity_symbols)
    # Gets dataloaders
    dataloaders = get_dataloaders(
        config,
        tasks,
        data_splits,
        entity_symbols,
        batch_on_the_fly_kg_adj,
    )
    slice_datasets = get_slicedatasets(config, slice_splits, entity_symbols)

    configure_optimizer(config)

    # Create models and add tasks
    if config.model_config.attn_class == "BERTNED":
        log_rank_0_info(logger, f"Starting NED-Base Model")
        assert (config.data_config.type_prediction.use_type_pred is
                False), f"NED-Base does not support type prediction"
        assert (
            config.data_config.word_embedding.use_sent_proj is False
        ), f"NED-Base requires word_embeddings.use_sent_proj to be False"
        model = EmmentalModel(name="NED-Base")
        model.add_tasks(
            ned_task.create_task(config, entity_symbols, slice_datasets))
    else:
        log_rank_0_info(logger, f"Starting Bootleg Model")
        model = EmmentalModel(name="Bootleg")
        # TODO: make this more general for other tasks -- iterate through list of tasks
        # and add task for each
        model.add_task(
            ned_task.create_task(config, entity_symbols, slice_datasets))
        if TYPE_PRED_TASK in tasks:
            model.add_task(
                type_pred_task.create_task(config, entity_symbols,
                                           slice_datasets))
            # Add the mention type embedding to the embedding payload
            type_pred_task.update_ned_task(model)

    # Print param counts
    if mode == "train":
        log_rank_0_debug(logger, "PARAMS WITH GRAD\n" + "=" * 30)
        total_params = count_parameters(model,
                                        requires_grad=True,
                                        logger=logger)
        log_rank_0_info(logger, f"===> Total Params With Grad: {total_params}")
        log_rank_0_debug(logger, "PARAMS WITHOUT GRAD\n" + "=" * 30)
        total_params = count_parameters(model,
                                        requires_grad=False,
                                        logger=logger)
        log_rank_0_info(logger,
                        f"===> Total Params Without Grad: {total_params}")

    # Load the best model from the pretrained model
    if config["model_config"]["model_path"] is not None:
        model.load(config["model_config"]["model_path"])

    # Barrier
    if config["learner_config"]["local_rank"] == 0:
        torch.distributed.barrier()

    # Train model
    if mode == "train":
        emmental_learner = EmmentalLearner()
        emmental_learner._set_optimizer(model)
        emmental_learner.learn(model, dataloaders)
        if config.learner_config.local_rank in [0, -1]:
            model.save(f"{emmental.Meta.log_path}/last_model.pth")

    # Multi-gpu DataParallel eval (NOT distributed)
    if mode in ["eval", "dump_embs", "dump_preds"]:
        # This happens inside EmmentalLearner for training
        if (config["learner_config"]["local_rank"] == -1
                and config["model_config"]["dataparallel"]):
            model._to_dataparallel()

    # If just finished training a model or in eval mode, run eval
    if mode in ["train", "eval"]:
        scores = model.score(dataloaders)
        # Save metrics and models
        log_rank_0_info(logger, f"Saving metrics to {emmental.Meta.log_path}")
        log_rank_0_info(logger, f"Metrics: {scores}")
        scores["log_path"] = emmental.Meta.log_path
        if config.learner_config.local_rank in [0, -1]:
            write_to_file(f"{emmental.Meta.log_path}/{mode}_metrics.txt",
                          scores)
            eval_utils.write_disambig_metrics_to_csv(
                f"{emmental.Meta.log_path}/{mode}_disambig_metrics.csv",
                scores)
        return scores

    # If you want detailed dumps, save model outputs
    assert mode in [
        "dump_preds",
        "dump_embs",
    ], 'Mode must be "dump_preds" or "dump_embs"'
    dump_embs = mode == "dump_embs"
    assert (
        len(dataloaders) == 1
    ), f"We should only have length 1 dataloaders for dump_embs and dump_preds!"
    final_result_file, final_out_emb_file = None, None
    if config.learner_config.local_rank in [0, -1]:
        # Setup files/folders
        filename = os.path.basename(dataloaders[0].dataset.raw_filename)
        log_rank_0_debug(
            logger,
            f"Collecting sentence to mention map {os.path.join(config.data_config.data_dir, filename)}",
        )
        sentidx2num_mentions, sent_idx2row = eval_utils.get_sent_idx2num_mens(
            os.path.join(config.data_config.data_dir, filename))
        log_rank_0_debug(logger, f"Done collecting sentence to mention map")
        eval_folder = eval_utils.get_eval_folder(filename)
        subeval_folder = os.path.join(eval_folder, "batch_results")
        utils.ensure_dir(subeval_folder)
        # Will keep track of sentences dumped already. These will only be ones with mentions
        all_dumped_sentences = set()
        number_dumped_batches = 0
        total_mentions_seen = 0
        all_result_files = []
        all_out_emb_files = []
        # Iterating over batches of predictions
        for res_i, res_dict in enumerate(
                eval_utils.batched_pred_iter(
                    model,
                    dataloaders[0],
                    config.run_config.eval_accumulation_steps,
                    sentidx2num_mentions,
                )):
            (
                result_file,
                out_emb_file,
                final_sent_idxs,
                mentions_seen,
            ) = eval_utils.disambig_dump_preds(
                res_i,
                total_mentions_seen,
                config,
                res_dict,
                sentidx2num_mentions,
                sent_idx2row,
                subeval_folder,
                entity_symbols,
                dump_embs,
                NED_TASK,
            )
            all_dumped_sentences.update(final_sent_idxs)
            all_result_files.append(result_file)
            all_out_emb_files.append(out_emb_file)
            total_mentions_seen += mentions_seen
            number_dumped_batches += 1

        # Dump the sentences that had no mentions and were not already dumped
        # Assert all remaining sentences have no mentions
        assert all(
            v == 0 for k, v in sentidx2num_mentions.items()
            if k not in all_dumped_sentences
        ), (f"Sentences with mentions were not dumped: "
            f"{[k for k, v in sentidx2num_mentions.items() if k not in all_dumped_sentences]}"
            )
        empty_sentidx2row = {
            k: v
            for k, v in sent_idx2row.items() if k not in all_dumped_sentences
        }
        empty_resultfile = eval_utils.get_result_file(number_dumped_batches,
                                                      subeval_folder)
        all_result_files.append(empty_resultfile)
        # Dump the outputs
        eval_utils.write_data_labels_single(
            sentidx2row=empty_sentidx2row,
            output_file=empty_resultfile,
            filt_emb_data=None,
            sental2embid={},
            alias_cand_map=entity_symbols.get_alias2qids(),
            qid2eid=entity_symbols.get_qid2eid(),
            result_alias_offset=total_mentions_seen,
            train_in_cands=config.data_config.train_in_candidates,
            max_cands=entity_symbols.max_candidates,
            dump_embs=dump_embs,
        )

        log_rank_0_info(
            logger,
            f"Finished dumping. Merging results across accumulation steps.")
        # Final result files for labels and embeddings
        final_result_file = os.path.join(eval_folder,
                                         config.run_config.result_label_file)
        # Copy labels
        with open(final_result_file, "wb") as output:
            for file in all_result_files:
                with open(file, "rb") as in_f:
                    shutil.copyfileobj(in_f, output)
        log_rank_0_info(logger, f"Bootleg labels saved at {final_result_file}")
        # Try to copy embeddings
        if dump_embs:
            final_out_emb_file = os.path.join(
                eval_folder, config.run_config.result_emb_file)
            log_rank_0_info(
                logger,
                f"Trying to merge numpy embedding arrays. "
                f"If your machine is limited in memory, this may cause OOM errors. "
                f"Is that happens, result files should be saved in {subeval_folder}.",
            )
            all_arrays = [np.load(npfile) for npfile in all_out_emb_files]
            np.save(final_out_emb_file, np.concatenate(all_arrays))
            log_rank_0_info(
                logger, f"Bootleg embeddings saved at {final_out_emb_file}")

        # Cleanup
        try_rmtree(subeval_folder)
    return final_result_file, final_out_emb_file
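A minimal call sketch for run_model, assuming the config comes from parse_boot_and_emm_args (the path is a placeholder); the mode strings are the ones listed in the docstring:

config = parse_boot_and_emm_args("configs/my_run_config.json")
scores = run_model("train", config)                    # returns metric scores
label_file, emb_file = run_model("dump_embs", config)  # returns dump file paths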
Example #13
    def test_filter_embs(self):
        entity_dump_dir = "test/data/entity_loader/entity_data/entity_mappings"
        entity_db = EntitySymbols.load_from_cache(
            load_dir=entity_dump_dir,
            alias_cand_map_file="alias2qids.json",
            alias_idx_file="alias2id.json",
        )
        num_topk_entities = 2
        old2new_eid = {0: 0, -1: -1, 2: 1, 3: 2}
        qid2topk_eid = {"Q1": 2, "Q2": 1, "Q3": 2, "Q4": 2}
        toes_eid = 2
        state_dict = {
            "module_pool": {
                "learned": {
                    "learned_entity_embedding.weight": torch.Tensor(
                        [
                            [1.0, 2, 3, 4, 5],
                            [2, 2, 2, 2, 2],
                            [3, 3, 3, 3, 3],
                            [4, 4, 4, 4, 4],
                            [5, 5, 5, 5, 5],
                            [0, 0, 0, 0, 0],
                        ]
                    )
                }
            }
        }

        gold_state_dict = {
            "module_pool": {
                "learned": {
                    "learned_entity_embedding.weight": torch.Tensor(
                        [
                            [1.0, 2, 3, 4, 5],
                            [3, 3, 3, 3, 3],
                            [4, 4, 4, 4, 4],
                            [0, 0, 0, 0, 0],
                        ]
                    )
                }
            }
        }

        new_state_dict = filter_embs(
            num_topk_entities,
            entity_db,
            old2new_eid,
            qid2topk_eid,
            toes_eid,
            state_dict,
        )
        gld = gold_state_dict
        nsd = new_state_dict
        keys_to_check = ["module_pool", "learned", "learned_entity_embedding.weight"]
        for k in keys_to_check:
            assert k in nsd
            assert k in gld
            if type(gld[k]) is dict:
                gld = gld[k]
                nsd = nsd[k]
                continue
            else:
                assert torch.equal(nsd[k], gld[k])
Example #14
    def load_from_cache(
        cls,
        load_dir,
        edit_mode=False,
        verbose=False,
        no_kg=False,
        no_type=False,
        type_systems_to_load=None,
    ):
        """Loaded a pre-saved profile.

        Args:
            load_dir: load directory
            edit_mode: edit mode flag, default False
            verbose: verbose flag, default False
            no_kg: load kg or not flag, default False
            no_type: load types or not flag, default False. If True, this will ignore type_systems_to_load.
            type_systems_to_load: list of type systems to load, default is None which means all types systems

        Returns: entity profile object
        """
        # Check type system input
        load_dir = Path(load_dir)
        type_subfolder = load_dir / TYPE_SUBFOLDER
        if type_systems_to_load is not None:
            if not isinstance(type_systems_to_load, list):
                raise ValueError(
                    f"`type_systems` must be a list of subfolders in {type_subfolder}"
                )
            for sys in type_systems_to_load:
                if sys not in list([p.name for p in type_subfolder.iterdir()]):
                    raise ValueError(
                        f"`type_systems` must be a list of subfolders in {type_subfolder}. {sys} is not one."
                    )

        if verbose:
            print("Loading Entity Symbols")
        entity_symbols = EntitySymbols.load_from_cache(
            load_dir / ENTITY_SUBFOLDER,
            edit_mode=edit_mode,
            verbose=verbose,
        )
        if no_type:
            print(
                f"Not loading type information. We will act as if there is no types associated with any entity "
                f"and will not modify the types in any way, even if calling `add`."
            )
        type_sys_dict = {}
        for fold in type_subfolder.iterdir():
            if ((not no_type) and (type_systems_to_load is None
                                   or fold.name in type_systems_to_load)
                    and (fold.is_dir())):
                if verbose:
                    print(f"Loading Type Symbols from {fold}")
                type_sys_dict[fold.name] = TypeSymbols.load_from_cache(
                    type_subfolder / fold.name,
                    edit_mode=edit_mode,
                    verbose=verbose,
                )
        if verbose:
            print(f"Loading KG Symbols")
        if no_kg:
            print(
                f"Not loading KG information. We will act as if there is not KG connections between entities. "
                f"We will not modify the KG information in any way, even if calling `add`."
            )
        kg_symbols = None
        if not no_kg:
            kg_symbols = KGSymbols.load_from_cache(
                load_dir / KG_SUBFOLDER,
                edit_mode=edit_mode,
                verbose=verbose,
            )
        return cls(entity_symbols, type_sys_dict, kg_symbols, edit_mode,
                   verbose)
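A usage sketch for the loader above, assuming the classmethod lives on an EntityProfile-style class (the class itself is not shown here) and that "wiki" names one of the on-disk type-system subfolders:

profile = EntityProfile.load_from_cache(
    "data/entity_db",               # placeholder load directory
    edit_mode=True,                 # allow add/edit calls on the profile
    type_systems_to_load=["wiki"],  # load only this type system
    no_kg=True,                     # skip KG connections entirely
)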
Example #15
    def __init__(
        self,
        config=None,
        device=None,
        max_alias_len=6,
        cand_map=None,
        threshold=0.0,
        cache_dir=None,
        model_name=None,
        verbose=False,
    ):
        self.max_alias_len = max_alias_len  # maximum number of words in an extracted mention
        self.verbose = verbose
        self.threshold = threshold  # minimum probability of prediction to return a mention

        if not cache_dir:
            self.cache_dir = get_default_cache()
        else:
            self.cache_dir = Path(cache_dir)
        self.model_path = self.cache_dir / "models"
        self.data_path = self.cache_dir / "data"

        if not model_name:
            model_name = "bootleg_uncased"

        assert model_name in {
            "bootleg_cased",
            "bootleg_cased_mini",
            "bootleg_uncased",
            "bootleg_uncased_mini",
        }, (f"model_name must be one of [bootleg_cased, bootleg_cased_mini, "
            f"bootleg_uncased_mini, bootleg_uncased]. You have {model_name}.")

        if not config:
            self.cache_dir.mkdir(parents=True, exist_ok=True)
            self.model_path.mkdir(parents=True, exist_ok=True)
            self.data_path.mkdir(parents=True, exist_ok=True)
            create_sources(self.model_path, self.data_path, model_name)
            self.config = create_config(self.model_path, self.data_path,
                                        model_name)
        else:
            if "emmental" in config:
                config = parse_boot_and_emm_args(config)
            self.config = config
            # Ensure some of the critical annotator args are the correct type
            self.config.data_config.max_aliases = int(
                self.config.data_config.max_aliases)
            self.config.run_config.eval_batch_size = int(
                self.config.run_config.eval_batch_size)
            self.config.data_config.max_seq_len = int(
                self.config.data_config.max_seq_len)
            self.config.data_config.train_in_candidates = bool(
                self.config.data_config.train_in_candidates)

        if not device:
            device = 0 if torch.cuda.is_available() else -1

        if self.verbose:
            self.config.run_config.log_level = "DEBUG"
        else:
            self.config.run_config.log_level = "INFO"

        self.torch_device = (torch.device(device)
                             if device != -1 else torch.device("cpu"))
        self.config.model_config.device = device

        log_level = logging.getLevelName(
            self.config["run_config"]["log_level"].upper())
        emmental.init(
            log_dir=self.config["meta_config"]["log_path"],
            config=self.config,
            use_exact_log_path=self.config["meta_config"]
            ["use_exact_log_path"],
            level=log_level,
        )

        logger.debug("Reading entity database")
        self.entity_db = EntitySymbols.load_from_cache(
            os.path.join(
                self.config.data_config.entity_dir,
                self.config.data_config.entity_map_dir,
            ),
            alias_cand_map_file=self.config.data_config.alias_cand_map,
            alias_idx_file=self.config.data_config.alias_idx_map,
        )
        logger.debug("Reading word tokenizers")
        self.tokenizer = BertTokenizer.from_pretrained(
            self.config.data_config.word_embedding.bert_model,
            do_lower_case="uncased" in self.config.data_config.word_embedding.bert_model,
            cache_dir=self.config.data_config.word_embedding.cache_dir,
        )

        # Create tasks
        tasks = [NED_TASK]
        if self.config.data_config.type_prediction.use_type_pred is True:
            tasks.append(TYPE_PRED_TASK)
        self.task_to_label_dict = {t: NED_TASK_TO_LABEL[t] for t in tasks}

        # Create tasks
        self.model = EmmentalModel(name="Bootleg")
        self.model.add_task(ned_task.create_task(self.config, self.entity_db))
        if TYPE_PRED_TASK in tasks:
            self.model.add_task(
                type_pred_task.create_task(self.config, self.entity_db))
            # Add the mention type embedding to the embedding payload
            type_pred_task.update_ned_task(self.model)

        logger.debug("Loading model")
        # Load the best model from the pretrained model
        assert (
            self.config["model_config"]["model_path"] is not None
        ), f"Must have a model to load in the model_path for the BootlegAnnotator"
        self.model.load(self.config["model_config"]["model_path"])
        self.model.eval()
        if cand_map is None:
            alias_map = self.entity_db.get_alias2qids()
        else:
            logger.debug(f"Loading candidate map")
            alias_map = ujson.load(open(cand_map))

        self.all_aliases_trie = get_all_aliases(alias_map, verbose)

        logger.debug("Reading in alias table")
        self.alias2cands = AliasEntityTable(
            data_config=self.config.data_config, entity_symbols=self.entity_db)

        # get batch_on_the_fly embeddings
        self.batch_on_the_fly_embs = get_dataloader_embeddings(
            self.config, self.entity_db)
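A short usage sketch for the annotator above (named BootlegAnnotator per the assert message); the label_mentions call is an assumption based on Bootleg's tutorials and is not shown in this snippet:

ann = BootlegAnnotator(device=-1, model_name="bootleg_uncased", verbose=True)
preds = ann.label_mentions("How many people are in Lincoln")  # hypothetical call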
Example #16
def create_task(args, entity_symbols=None, slice_datasets=None):
    """Creates a type prediction task.

    Args:
        args: args
        entity_symbols: entity symbols
        slice_datasets: slice datasets used in scorer (default None)

    Returns: EmmentalTask for type prediction
    """
    if entity_symbols is None:
        entity_symbols = EntitySymbols.load_from_cache(
            load_dir=os.path.join(
                args.data_config.entity_dir, args.data_config.entity_map_dir
            ),
            alias_cand_map_file=args.data_config.alias_cand_map,
            alias_idx_file=args.data_config.alias_idx_map,
        )

    # Create sentence encoder
    bert_model = BertEncoder(
        args.data_config.word_embedding, output_size=args.model_config.hidden_size
    )

    # Create type prediction module
    # Add 1 for pad type
    type_prediction = TypePred(
        args.model_config.hidden_size,
        args.data_config.type_prediction.dim,
        args.data_config.type_prediction.num_types + 1,
        embedding_utils.get_max_candidates(entity_symbols, args.data_config),
    )

    # Create scorer
    sliced_scorer = BootlegSlicedScorer(
        args.data_config.train_in_candidates, slice_datasets
    )

    # Create module pool
    # BERT model will be shared across tasks as long as the name matches
    module_pool = nn.ModuleDict(
        {BERT_MODEL_NAME: bert_model, "type_prediction": type_prediction}
    )

    # Create task flow
    task_flow = [
        {
            "name": BERT_MODEL_NAME,
            "module": BERT_MODEL_NAME,
            "inputs": [
                ("_input_", "entity_cand_eid"),
                ("_input_", "token_ids"),
            ],  # We pass the entity_cand_eids to BERT in case of embeddings that require word information
        },
        {
            "name": "type_prediction",
            "module": "type_prediction",  # output: embedding_dict, batch_type_pred
            "inputs": [
                (BERT_MODEL_NAME, 0),  # sentence embedding
                ("_input_", "start_span_idx"),
            ],
        },
    ]

    return EmmentalTask(
        name=TYPE_PRED_TASK,
        module_pool=module_pool,
        task_flow=task_flow,
        loss_func=partial(type_loss, "type_prediction"),
        output_func=partial(type_output, "type_prediction"),
        require_prob_for_eval=False,
        require_pred_for_eval=True,
        scorer=Scorer(
            customize_metric_funcs={
                f"{TYPE_PRED_TASK}_scorer": sliced_scorer.type_pred_score
            }
        ),
    )
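As in run_model above, this task is added next to the NED task when type_prediction.use_type_pred is set; a minimal sketch mirroring that wiring:

model = EmmentalModel(name="Bootleg")
model.add_task(ned_task.create_task(config, entity_symbols))
model.add_task(type_pred_task.create_task(config, entity_symbols))
type_pred_task.update_ned_task(model)  # add the mention type embedding to the payload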
Example #17
def main():
    gl_start = time.time()
    multiprocessing.set_start_method("spawn")
    args = get_arg_parser().parse_args()
    print(json.dumps(vars(args), indent=4))
    utils.ensure_dir(args.data_dir)

    out_dir = os.path.join(args.data_dir, args.out_subdir)
    if os.path.exists(out_dir):
        shutil.rmtree(out_dir)
    os.makedirs(out_dir, exist_ok=True)
    # Reading in files
    in_files_train = glob.glob(os.path.join(args.data_dir, "*.jsonl"))
    in_files_cand = glob.glob(
        os.path.join(args.contextual_cand_data, "*.jsonl"))
    assert (
        len(in_files_train) > 0
    ), f"We didn't find any train files at {args.data_dir}"
    assert (
        len(in_files_cand) > 0
    ), f"We didn't find any contextual files at {args.contextual_cand_data}"
    in_files = []
    for file in in_files_train:
        file_name = os.path.basename(file)
        tag = os.path.splitext(file_name)[0]
        is_train = "train" in tag
        if is_train:
            print(
                f"{file_name} is a training dataset...will be processed as such"
            )
        pair = None
        for f in in_files_cand:
            if tag in f:
                pair = f
                break
        assert pair is not None, f"{file_name} name, {tag} tag"
        out_file = os.path.join(out_dir, file_name)
        in_files.append([file, pair, out_file, is_train])
    final_cand_map = {}
    max_cands = 0
    for pair in in_files:
        print(
            f"Reading in {pair[0]} with cand maps {pair[1]} and dumping to {pair[2]}"
        )
        new_alias2qids = merge_data(
            args.processes,
            args.train_in_candidates,
            args.keep_orig,
            args.max_candidates,
            pair,
            args.entity_dump,
        )
        for al in new_alias2qids:
            assert al not in final_cand_map, f"{al} is already in final_cand_map"
            final_cand_map[al] = new_alias2qids[al]
            max_cands = max(max_cands, len(final_cand_map[al]))

    print(f"Buidling new entity symbols")
    entity_dump = EntitySymbols.load_from_cache(load_dir=args.entity_dump)
    entity_dump_new = EntitySymbols(
        max_candidates=max_cands,
        alias2qids=final_cand_map,
        qid2title=entity_dump.get_qid2title(),
    )
    out_dir = os.path.join(out_dir, "entity_db/entity_mappings")
    entity_dump_new.save(out_dir)
    print(f"Finished in {time.time() - gl_start}s")