Example #1
    def _logging(
        self,
        model: EmmentalModel,
        dataloaders: List[EmmentalDataLoader],
        batch_size: int,
    ) -> Dict[str, float]:
        r"""Checking if it's time to evaluting or checkpointing.

        Args:
          model(EmmentalModel): The model to log.
          dataloaders(List[EmmentalDataLoader]): The data to evaluate.
          batch_size(int): Batch size.

        Returns:
          dict: The score dict.

        """

        # Switch to eval mode for evaluation
        model.eval()

        metric_dict = dict()

        self.logging_manager.update(batch_size)

        # Log the loss and lr
        metric_dict.update(self._aggregate_running_metrics(model))

        # Evaluate the model and log the metric
        trigger_evaluation = self.logging_manager.trigger_evaluation()
        if trigger_evaluation:

            # Log task specific metric
            metric_dict.update(
                self._evaluate(
                    model, dataloaders, Meta.config["learner_config"]["valid_split"]
                )
            )

            self.logging_manager.write_log(metric_dict)

            self._reset_losses()

        # Log the metric dict at every evaluation trigger or at each full epoch
        if Meta.config["meta_config"]["verbose"] and (
            trigger_evaluation
            or self.logging_manager.epoch_total == int(self.logging_manager.epoch_total)
        ):
            logger.info(
                f"{self.logging_manager.counter_unit.capitalize()}: "
                f"{self.logging_manager.unit_total:.2f} {metric_dict}"
            )

        # Checkpoint the model
        if self.logging_manager.trigger_checkpointing():
            self.logging_manager.checkpoint_model(
                model, self.optimizer, self.lr_scheduler, metric_dict
            )

            self.logging_manager.write_log(metric_dict)

            self._reset_losses()

        # Switch to train mode
        model.train()

        return metric_dict
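
    # A minimal sketch of how `_logging` is driven (hypothetical loop and
    # variable names; in the real learner this runs inside the training loop
    # after every batch):
    #
    #     for batch in train_dataloader:
    #         ...  # forward pass, loss, backward, optimizer step
    #         # update counters; evaluate/checkpoint when an interval triggers
    #         metric_dict = self._logging(model, dataloaders, batch_size)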
Example #2
class BootlegAnnotator(object):
    """BootlegAnnotator class: convenient wrapper of preprocessing and model
    eval to allow for annotating single sentences at a time for quick
    experimentation, e.g. in notebooks.

    Args:
        config: model config (default None)
        device: model device, -1 for CPU (default None)
        max_alias_len: maximum alias length (default 6)
        cand_map: alias candidate map (default None)
        threshold: probability threshold (default 0.0)
        cache_dir: cache directory (default None)
        model_name: model name (default None)
        verbose: verbose boolean (default False)
    """
    def __init__(
        self,
        config=None,
        device=None,
        max_alias_len=6,
        cand_map=None,
        threshold=0.0,
        cache_dir=None,
        model_name=None,
        verbose=False,
    ):
        self.max_alias_len = max_alias_len  # maximum length (in words) of an extracted mention
        self.verbose = verbose
        self.threshold = threshold  # minimum probability for a prediction to be returned

        if not cache_dir:
            self.cache_dir = get_default_cache()
        else:
            self.cache_dir = Path(cache_dir)
        self.model_path = self.cache_dir / "models"
        self.data_path = self.cache_dir / "data"

        if not model_name:
            model_name = "bootleg_uncased"

        assert model_name in {
            "bootleg_cased",
            "bootleg_cased_mini",
            "bootleg_uncased",
            "bootleg_uncased_mini",
        }, (f"model_name must be one of [bootleg_cased, bootleg_cased_mini, "
            f"bootleg_uncased_mini, bootleg_uncased]. You have {model_name}.")

        if not config:
            self.cache_dir.mkdir(parents=True, exist_ok=True)
            self.model_path.mkdir(parents=True, exist_ok=True)
            self.data_path.mkdir(parents=True, exist_ok=True)
            create_sources(self.model_path, self.data_path, model_name)
            self.config = create_config(self.model_path, self.data_path,
                                        model_name)
        else:
            if "emmental" in config:
                config = parse_boot_and_emm_args(config)
            self.config = config
            # Ensure some of the critical annotator args are the correct type
            self.config.data_config.max_aliases = int(
                self.config.data_config.max_aliases)
            self.config.run_config.eval_batch_size = int(
                self.config.run_config.eval_batch_size)
            self.config.data_config.max_seq_len = int(
                self.config.data_config.max_seq_len)
            self.config.data_config.train_in_candidates = bool(
                self.config.data_config.train_in_candidates)

        if not device:
            device = 0 if torch.cuda.is_available() else -1

        if self.verbose:
            self.config.run_config.log_level = "DEBUG"
        else:
            self.config.run_config.log_level = "INFO"

        self.torch_device = (torch.device(device)
                             if device != -1 else torch.device("cpu"))
        self.config.model_config.device = device

        log_level = logging.getLevelName(
            self.config["run_config"]["log_level"].upper())
        emmental.init(
            log_dir=self.config["meta_config"]["log_path"],
            config=self.config,
            use_exact_log_path=self.config["meta_config"]
            ["use_exact_log_path"],
            level=log_level,
        )

        logger.debug("Reading entity database")
        self.entity_db = EntitySymbols.load_from_cache(
            os.path.join(
                self.config.data_config.entity_dir,
                self.config.data_config.entity_map_dir,
            ),
            alias_cand_map_file=self.config.data_config.alias_cand_map,
            alias_idx_file=self.config.data_config.alias_idx_map,
        )
        logger.debug("Reading word tokenizers")
        self.tokenizer = BertTokenizer.from_pretrained(
            self.config.data_config.word_embedding.bert_model,
            do_lower_case=True if "uncased"
            in self.config.data_config.word_embedding.bert_model else False,
            cache_dir=self.config.data_config.word_embedding.cache_dir,
        )

        # Create tasks
        tasks = [NED_TASK]
        if self.config.data_config.type_prediction.use_type_pred:
            tasks.append(TYPE_PRED_TASK)
        self.task_to_label_dict = {t: NED_TASK_TO_LABEL[t] for t in tasks}

        # Create the model and add the tasks
        self.model = EmmentalModel(name="Bootleg")
        self.model.add_task(ned_task.create_task(self.config, self.entity_db))
        if TYPE_PRED_TASK in tasks:
            self.model.add_task(
                type_pred_task.create_task(self.config, self.entity_db))
            # Add the mention type embedding to the embedding payload
            type_pred_task.update_ned_task(self.model)

        logger.debug("Loading model")
        # Load the best model from the pretrained model
        assert (
            self.config["model_config"]["model_path"] is not None
        ), f"Must have a model to load in the model_path for the BootlegAnnotator"
        self.model.load(self.config["model_config"]["model_path"])
        self.model.eval()
        if cand_map is None:
            alias_map = self.entity_db.get_alias2qids()
        else:
            logger.debug(f"Loading candidate map")
            alias_map = ujson.load(open(cand_map))

        self.all_aliases_trie = get_all_aliases(alias_map, verbose)

        logger.debug("Reading in alias table")
        self.alias2cands = AliasEntityTable(
            data_config=self.config.data_config, entity_symbols=self.entity_db)

        # get batch_on_the_fly embeddings
        self.batch_on_the_fly_embs = get_dataloader_embeddings(
            self.config, self.entity_db)

    def extract_mentions(self, text, label_func):
        """Wrapper function for mention extraction.

        Args:
            text: text to extract mentions from
            label_func: function that performs extraction; input is (text, alias trie, max alias length),
                        output is a list of found aliases and found spans

        Returns: JSON object of sentence to be used in eval
        """
        found_aliases, found_spans = label_func(text, self.all_aliases_trie,
                                                self.max_alias_len)
        return {
            "sentence": text,
            "aliases": found_aliases,
            "spans": found_spans,
            # we don't know the true QID
            "qids": ["Q-1" for i in range(len(found_aliases))],
            "gold": [True for i in range(len(found_aliases))],
        }
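
    # Example usage (a sketch, assuming `annotator` is a constructed
    # BootlegAnnotator; `find_aliases_in_sentence_tag` is the default
    # extractor used by `label_mentions` below):
    #
    #     sample = annotator.extract_mentions(
    #         "Abraham Lincoln was a president", find_aliases_in_sentence_tag
    #     )
    #     # sample["aliases"] / sample["spans"] hold the found mentions;
    #     # sample["qids"] is filled with "Q-1" placeholders.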

    def set_threshold(self, value):
        """Sets threshold.

        Args:
            value: threshold value

        Returns:
        """
        self.threshold = value

    def label_mentions(self,
                       text_list,
                       label_func=find_aliases_in_sentence_tag):
        """Extracts mentions and runs disambiguation.

        Args:
            text_list: list of text to disambiguate (or single sentence)
            label_func: mention extraction function (optional)

        Returns: Dict of

            * ``qids``: final predicted QIDs,
            * ``probs``: final predicted probs,
            * ``titles``: final predicted titles,
            * ``cands``: all entity candidates,
            * ``cand_probs``: probabilities of all candidates,
            * ``spans``: final extracted word spans,
            * ``aliases``: final extracted aliases,
        """
        if isinstance(text_list, str):
            text_list = [text_list]
        else:
            assert (isinstance(text_list, list) and len(text_list) > 0
                    and isinstance(text_list[0], str)
                    ), "We only accept inputs of strings and lists of strings"

        ebs = int(self.config.run_config.eval_batch_size)
        self.config.data_config.max_aliases = int(
            self.config.data_config.max_aliases)
        total_start_exs = 0
        total_final_exs = 0
        dropped_by_thresh = 0

        final_char_spans = []

        batch_example_aliases = []
        batch_example_aliases_locs_start = []
        batch_example_aliases_locs_end = []
        batch_example_alias_list_pos = []
        batch_example_true_entities = []
        batch_word_indices = []
        batch_spans_arr = []
        batch_aliases_arr = []
        batch_idx_unq = []
        batch_subsplit_idx = []
        for idx_unq, text in tqdm(
                enumerate(text_list),
                desc="Prepping data",
                total=len(text_list),
                disable=not self.verbose,
        ):
            sample = self.extract_mentions(text, label_func)
            total_start_exs += len(sample["aliases"])
            char_spans = self.get_char_spans(sample["spans"], text)

            final_char_spans.append(char_spans)

            (
                idxs_arr,
                aliases_to_predict_per_split,
                spans_arr,
                phrase_tokens_arr,
                pos_idxs,
            ) = sentence_utils.split_sentence(
                max_aliases=self.config.data_config.max_aliases,
                phrase=sample["sentence"],
                spans=sample["spans"],
                aliases=sample["aliases"],
                aliases_seen_by_model=list(range(len(sample["aliases"]))),
                seq_len=self.config.data_config.max_seq_len,
                is_bert=True,
                tokenizer=self.tokenizer,
            )
            aliases_arr = [[sample["aliases"][idx] for idx in idxs]
                           for idxs in idxs_arr]
            old_spans_arr = [[sample["spans"][idx] for idx in idxs]
                             for idxs in idxs_arr]
            qids_arr = [[sample["qids"][idx] for idx in idxs]
                        for idxs in idxs_arr]
            word_indices_arr = [
                self.tokenizer.convert_tokens_to_ids(pt)
                for pt in phrase_tokens_arr
            ]
            # iterate over each sample in the split

            for sub_idx in range(len(idxs_arr)):
                # ====================================================
                # GENERATE MODEL INPUTS
                # ====================================================
                aliases_to_predict_arr = aliases_to_predict_per_split[sub_idx]

                assert (
                    len(aliases_to_predict_arr) > 0
                ), "There are no aliases to predict for an example. This should not happen at this point."
                assert (
                    len(aliases_arr[sub_idx]) <=
                    self.config.data_config.max_aliases
                ), f"{sample} should have no more than {self.config.data_config.max_aliases} aliases."

                example_aliases = np.ones(
                    self.config.data_config.max_aliases) * PAD_ID
                example_aliases_locs_start = (
                    np.ones(self.config.data_config.max_aliases) * PAD_ID)
                example_aliases_locs_end = (
                    np.ones(self.config.data_config.max_aliases) * PAD_ID)
                example_alias_list_pos = (
                    np.ones(self.config.data_config.max_aliases) * PAD_ID)
                example_true_entities = (
                    np.ones(self.config.data_config.max_aliases) * PAD_ID)

                for mention_idx, alias in enumerate(aliases_arr[sub_idx]):
                    span_start_idx, span_end_idx = spans_arr[sub_idx][
                        mention_idx]
                    # generate indexes into alias table.
                    alias_trie_idx = self.entity_db.get_alias_idx(alias)
                    alias_qids = np.array(self.entity_db.get_qid_cands(alias))
                    if qids_arr[sub_idx][mention_idx] not in alias_qids:
                        # assert not data_args.train_in_candidates
                        if not self.config.data_config.train_in_candidates:
                            # set class label to be "not in candidate set"
                            true_entity_idx = 0
                        else:
                            true_entity_idx = -2
                    else:
                        # Here we are getting the correct class label for training.
                        # Our training is "which of the max_entities entity candidates is the right one
                        # (class labels 1 to max_entities) or is it none of these (class label 0)".
                        # + (not discard_noncandidate_entities) is to ensure label 0 is
                        # reserved for "not in candidate set" class
                        true_entity_idx = np.nonzero(
                            alias_qids == qids_arr[sub_idx][mention_idx]
                        )[0][0] + (
                            not self.config.data_config.train_in_candidates)
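                    # Worked example (illustrative values): with
                    # train_in_candidates=False and alias_qids=["Q1", "Q2"],
                    # a gold QID of "Q2" yields label 2 (offset by 1 so that
                    # label 0 is reserved for "not in candidate set"), while a
                    # gold QID absent from alias_qids yields label 0.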
                    example_aliases[mention_idx] = alias_trie_idx
                    example_aliases_locs_start[mention_idx] = span_start_idx
                    # The span_idxs are [start, end). We want [start, end]. So subtract 1 from end idx.
                    example_aliases_locs_end[mention_idx] = span_end_idx - 1
                    example_alias_list_pos[mention_idx] = idxs_arr[sub_idx][
                        mention_idx]
                    # leave as -1 if it's not an alias we want to predict; we get these if we split a sentence
                    # and need to only predict subsets
                    if mention_idx in aliases_to_predict_arr:
                        example_true_entities[mention_idx] = true_entity_idx

                # get word indices
                word_indices = word_indices_arr[sub_idx]

                batch_example_aliases.append(example_aliases)
                batch_example_aliases_locs_start.append(
                    example_aliases_locs_start)
                batch_example_aliases_locs_end.append(example_aliases_locs_end)
                batch_example_alias_list_pos.append(example_alias_list_pos)
                batch_example_true_entities.append(example_true_entities)
                batch_word_indices.append(word_indices)
                batch_aliases_arr.append(aliases_arr[sub_idx])
                # Add the original sample spans because spans_arr is w.r.t. BERT subword tokens
                batch_spans_arr.append(old_spans_arr[sub_idx])
                batch_idx_unq.append(idx_unq)
                batch_subsplit_idx.append(sub_idx)

        batch_example_aliases = torch.tensor(batch_example_aliases).long()
        batch_example_aliases_locs_start = torch.tensor(
            batch_example_aliases_locs_start, device=self.torch_device)
        batch_example_aliases_locs_end = torch.tensor(
            batch_example_aliases_locs_end, device=self.torch_device)
        batch_example_true_entities = torch.tensor(batch_example_true_entities,
                                                   device=self.torch_device)
        batch_word_indices = torch.tensor(batch_word_indices,
                                          device=self.torch_device)

        final_pred_cands = [[] for _ in range(len(text_list))]
        final_all_cands = [[] for _ in range(len(text_list))]
        final_cand_probs = [[] for _ in range(len(text_list))]
        final_pred_probs = [[] for _ in range(len(text_list))]
        final_titles = [[] for _ in range(len(text_list))]
        final_spans = [[] for _ in range(len(text_list))]
        final_aliases = [[] for _ in range(len(text_list))]
        for b_i in tqdm(
                range(0, batch_example_aliases.shape[0], ebs),
                desc="Evaluating model",
                disable=not self.verbose,
        ):
            start_span_idx = batch_example_aliases_locs_start[b_i:b_i + ebs]
            end_span_idx = batch_example_aliases_locs_end[b_i:b_i + ebs]
            word_indices = batch_word_indices[b_i:b_i + ebs]
            alias_indices = batch_example_aliases[b_i:b_i + ebs]
            x_dict = self.get_forward_batch(start_span_idx, end_span_idx,
                                            word_indices, alias_indices)
            x_dict["guid"] = torch.arange(b_i,
                                          b_i + ebs,
                                          device=self.torch_device)

            (uid_bdict, _, prob_bdict, _) = self.model(  # type: ignore
                uids=x_dict["guid"],
                X_dict=x_dict,
                Y_dict=None,
                task_to_label_dict=self.task_to_label_dict,
                return_action_outputs=False,
            )
            # ====================================================
            # EVALUATE MODEL OUTPUTS
            # ====================================================
            # recover predictions
            probs = prob_bdict[NED_TASK]
            max_probs = probs.max(2)
            max_probs_indices = probs.argmax(2)
            for ex_i in range(probs.shape[0]):
                idx_unq = batch_idx_unq[b_i + ex_i]
                entity_cands = eval_utils.map_aliases_to_candidates(
                    self.config.data_config.train_in_candidates,
                    self.config.data_config.max_aliases,
                    self.entity_db.get_alias2qids(),
                    batch_aliases_arr[b_i + ex_i],
                )
                # reshape the per-example probs to (max_aliases, num_candidates)
                probs_ex = probs[ex_i].reshape(
                    self.config.data_config.max_aliases, probs.shape[2])
                for alias_idx, true_entity_pos_idx in enumerate(
                        batch_example_true_entities[b_i + ex_i]):
                    if true_entity_pos_idx != PAD_ID:
                        pred_idx = max_probs_indices[ex_i][alias_idx]
                        pred_prob = max_probs[ex_i][alias_idx].item()
                        all_cands = entity_cands[alias_idx]
                        pred_qid = all_cands[pred_idx]
                        if pred_prob > self.threshold:
                            final_all_cands[idx_unq].append(all_cands)
                            final_cand_probs[idx_unq].append(
                                probs_ex[alias_idx])
                            final_pred_cands[idx_unq].append(pred_qid)
                            final_pred_probs[idx_unq].append(pred_prob)
                            final_aliases[idx_unq].append(
                                batch_aliases_arr[b_i + ex_i][alias_idx])
                            final_spans[idx_unq].append(
                                batch_spans_arr[b_i + ex_i][alias_idx])
                            final_titles[idx_unq].append(
                                self.entity_db.get_title(pred_qid)
                                if pred_qid != "NC" else "NC")
                            total_final_exs += 1
                        else:
                            dropped_by_thresh += 1
        assert total_final_exs + dropped_by_thresh == total_start_exs, (
            f"Something went wrong and we have predicted fewer mentions than extracted. "
            f"Start {total_start_exs}, Out {total_final_exs}, No cand {dropped_by_thresh}"
        )
        res_dict = {
            "qids": final_pred_cands,
            "probs": final_pred_probs,
            "titles": final_titles,
            "cands": final_all_cands,
            "cand_probs": final_cand_probs,
            "spans": final_spans,
            "aliases": final_aliases,
        }
        return res_dict
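
    # Typical end-to-end usage (a sketch; model/config downloads are handled
    # by `__init__` as above, and the QID shown is illustrative):
    #
    #     ann = BootlegAnnotator(model_name="bootleg_uncased", device=-1)
    #     out = ann.label_mentions("Lincoln was the 16th president")
    #     out["qids"]    # one list of predicted QIDs per input sentence
    #     out["titles"]  # the corresponding entity titles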

    def get_forward_batch(self, start_span_idx, end_span_idx, token_ids,
                          alias_idx):
        """Preps the forward batch for disambiguation.

        Args:
            start_span_idx: start span tensor
            end_span_idx: end span tensor
            token_ids: word token tensor
            alias_idx: alias index used for extracting candidate eids

        Returns: X_dict used in Emmental
        """
        entity_cand_eid = self.alias2cands(alias_idx).long()
        entity_cand_eid_mask = entity_cand_eid == -1
        entity_cand_eid_noneg = torch.where(
            entity_cand_eid >= 0,
            entity_cand_eid,
            (torch.ones_like(entity_cand_eid, dtype=torch.long) *
             (self.entity_db.num_entities_with_pad_and_nocand - 1)),
        )

        kg_prepped_embs = {}
        for emb_key in self.batch_on_the_fly_embs:
            kg_adj = self.batch_on_the_fly_embs[emb_key]["kg_adj"]
            prep_func = self.batch_on_the_fly_embs[emb_key][
                "kg_adj_process_func"]
            batch_prep = []
            for j in range(entity_cand_eid_noneg.shape[0]):
                batch_prep.append(
                    prep_func(entity_cand_eid_noneg[j].cpu(),
                              kg_adj).reshape(1, -1))
            kg_prepped_embs[emb_key] = torch.tensor(batch_prep,
                                                    device=self.torch_device)

        X_dict = {
            "guids": [],
            "start_span_idx": start_span_idx,
            "end_span_idx": end_span_idx,
            "token_ids": token_ids,
            "entity_cand_eid": entity_cand_eid_noneg,
            "entity_cand_eid_mask": entity_cand_eid_mask,
            "batch_on_the_fly_kg_adj": kg_prepped_embs,
        }
        return X_dict

    def get_char_spans(self, spans, text):
        """Helper function to get character spans instead of default word
        spans.

        Args:
            spans: word spans
            text: text

        Returns: character spans
        """
        query_toks = text.split()
        char_spans = []
        for span in spans:
            space_btwn_toks = (len(" ".join(query_toks[0:span[0] + 1])) -
                               len(" ".join(query_toks[0:span[0]])) -
                               len(query_toks[span[0]]))
            char_b = len(" ".join(query_toks[0:span[0]])) + space_btwn_toks
            char_e = char_b + len(" ".join(query_toks[span[0]:span[1]]))
            char_spans.append([char_b, char_e])
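        # Worked example: for text = "Abraham Lincoln was a president", word
        # span [0, 2] ("Abraham Lincoln") maps to character span [0, 15] and
        # word span [4, 5] ("president") maps to [22, 31], i.e.
        # text[22:31] == "president".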
        return char_spans