Example #1
    def test_edge_case(self):
        # Edge-case lengths

        # Test maximum sequence length
        max_aliases = 30
        max_seq_len = 3

        # Manual data
        sentence = 'The big alias1 ran away from dogs and multi word alias2 and alias3 because we want our cat and our alias5'
        aliases = ["The big alias1", "multi word alias2 and alias3"]
        aliases_to_predict = [0, 1]
        spans = [[0, 3], [8, 13]]

        # Run function
        args = parser_utils.get_full_config("test/run_args/test_data.json")
        word_symbols = data_utils.load_wordsymbols(args.data_config)
        idxs_arr, aliases_to_predict_arr, spans_arr, phrase_tokens_arr = split_sentence(
            max_aliases, sentence, spans, aliases, aliases_to_predict,
            max_seq_len, word_symbols)

        # True data
        true_phrase_arr = [
            "The big alias1".split(), "multi word alias2".split()
        ]
        true_spans_arr = [[[0, 3]], [[0, 5]]]
        true_alias_to_predict_arr = [[0], [0]]
        true_aliases_arr = [["The big alias1"],
                            ["multi word alias2 and alias3"]]

        assert len(idxs_arr) == 2
        assert len(aliases_to_predict_arr) == 2
        assert len(spans_arr) == 2
        assert len(phrase_tokens_arr) == 2
        for i in range(len(idxs_arr)):
            self.assertEqual(len(phrase_tokens_arr[i]), max_seq_len)
            self.assertEqual(phrase_tokens_arr[i], true_phrase_arr[i])
            self.assertEqual(spans_arr[i], true_spans_arr[i])
            self.assertEqual(aliases_to_predict_arr[i],
                             true_alias_to_predict_arr[i])
            self.assertEqual([aliases[idx] for idx in idxs_arr[i]],
                             true_aliases_arr[i])
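
The behavior under test is windowing: each split is a max_seq_len slice of the sentence, with spans re-indexed relative to the window. Below is a minimal standalone sketch of that step; the helper and its one-alias-per-window policy are illustrative, not Bootleg's actual split_sentence.

def window_around_alias(tokens, span, max_seq_len):
    # Clamp a max_seq_len window to start at (or before) the alias, then
    # re-index the span relative to the window start.
    start, end = span
    w_start = max(0, min(start, len(tokens) - max_seq_len))
    window = tokens[w_start:w_start + max_seq_len]
    return window, [start - w_start, end - w_start]

# With max_seq_len=3, the alias "The big alias1" at span [0, 3] keeps
# exactly its own window:
# window_around_alias('The big alias1 ran away'.split(), [0, 3], 3)
# -> (['The', 'big', 'alias1'], [0, 3])
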
Example #2
    def test_split_sentence_alias_to_predict(self):
        # No splitting, but a change in aliases to predict; nothing should change
        max_aliases = 30
        max_seq_len = 24

        # Manually created data
        sentence = 'The big alias1 ran away from dogs and multi word alias2 and alias3 because we want our cat and our alias5'
        aliases = ["The big", "alias3", "alias5"]
        aliases_to_predict = [0, 1]
        spans = [[0, 2], [12, 13], [20, 21]]

        # Run function
        args = parser_utils.get_full_config("test/run_args/test_data.json")
        word_symbols = data_utils.load_wordsymbols(args.data_config)
        idxs_arr, aliases_to_predict_arr, spans_arr, phrase_tokens_arr = split_sentence(
            max_aliases, sentence, spans, aliases, aliases_to_predict,
            max_seq_len, word_symbols)

        # Truth data
        true_phrase_arr = [
            "The big alias1 ran away from dogs and multi word alias2 and alias3 because we want our cat and our alias5 <pad> <pad> <pad>"
            .split(" ")
        ]
        true_spans_arr = [[[0, 2], [12, 13], [20, 21]]]
        true_alias_to_predict_arr = [[0, 1]]
        true_aliases_arr = [["The big", "alias3", "alias5"]]

        assert len(idxs_arr) == 1
        assert len(aliases_to_predict_arr) == 1
        assert len(spans_arr) == 1
        assert len(phrase_tokens_arr) == 1
        for i in range(len(idxs_arr)):
            self.assertEqual(len(phrase_tokens_arr[i]), max_seq_len)
            self.assertEqual(phrase_tokens_arr[i], true_phrase_arr[i])
            self.assertEqual(spans_arr[i], true_spans_arr[i])
            self.assertEqual(aliases_to_predict_arr[i],
                             true_alias_to_predict_arr[i])
            self.assertEqual([aliases[idx] for idx in idxs_arr[i]],
                             true_aliases_arr[i])
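
When no split is needed, the only transformation is right-padding the token sequence out to max_seq_len, which is where the trailing <pad> tokens in the expected phrase come from. A sketch of that step (pad_to_len is a hypothetical helper; '<pad>' matches the pad token used in these word-based tests):

def pad_to_len(tokens, max_seq_len, pad_token="<pad>"):
    # Right-pad (or truncate) a token list to exactly max_seq_len entries.
    clipped = tokens[:max_seq_len]
    return clipped + [pad_token] * (max_seq_len - len(clipped))

# pad_to_len(['a', 'b', 'c'], 5) -> ['a', 'b', 'c', '<pad>', '<pad>']
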
Example #3
    def test_split_sentence_max_aliases(self):
        # Test if the sentence splits correctly when max_aliases is less than the number of aliases
        max_aliases = 2
        max_seq_len = 24

        # Manually created data
        sentence = 'The big alias1 ran away from dogs and multi word alias2 and alias3 because we want our cat and our alias5'
        aliases = ["The big", "alias3", "alias5"]
        aliases_to_predict = [0, 1, 2]
        spans = [[0, 2], [12, 13], [20, 21]]

        # Run function
        args = parser_utils.get_full_config("test/run_args/test_data.json")
        word_symbols = data_utils.load_wordsymbols(args.data_config)
        idxs_arr, aliases_to_predict_arr, spans_arr, phrase_tokens_arr = split_sentence(
            max_aliases, sentence, spans, aliases, aliases_to_predict,
            max_seq_len, word_symbols)

        # True data
        true_phrase_arr = [
            "The big alias1 ran away from dogs and multi word alias2 and alias3 because we want our cat and our alias5 <pad> <pad> <pad>"
            .split(" ")
        ] * 2
        true_spans_arr = [[[0, 2], [12, 13]], [[20, 21]]]
        true_alias_to_predict_arr = [[0, 1], [0]]
        true_aliases_arr = [["The big", "alias3"], ["alias5"]]

        assert len(idxs_arr) == 2
        assert len(aliases_to_predict_arr) == 2
        assert len(spans_arr) == 2
        assert len(phrase_tokens_arr) == 2
        for i in range(len(idxs_arr)):
            self.assertEqual(len(phrase_tokens_arr[i]), max_seq_len)
            self.assertEqual(phrase_tokens_arr[i], true_phrase_arr[i])
            self.assertEqual(spans_arr[i], true_spans_arr[i])
            self.assertEqual(aliases_to_predict_arr[i],
                             true_alias_to_predict_arr[i])
            self.assertEqual([aliases[idx] for idx in idxs_arr[i]],
                             true_aliases_arr[i])
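
When the sentence itself fits but carries more aliases than max_aliases, the split amounts to chunking the alias list. A simplified sketch of that policy (the real function also interacts with the sequence-length constraint):

def chunk_aliases(n_aliases, max_aliases):
    # Consecutive groups of at most max_aliases alias indices.
    idxs = list(range(n_aliases))
    return [idxs[i:i + max_aliases] for i in range(0, n_aliases, max_aliases)]

# chunk_aliases(3, max_aliases=2) -> [[0, 1], [2]], matching the two
# splits asserted above.
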
Example #4
    def label_mentions(self, text):
        sample = self.extract_mentions(text)
        idxs_arr, aliases_to_predict_per_split, spans_arr, phrase_tokens_arr = sentence_utils.split_sentence(
            max_aliases=self.args.data_config.max_aliases,
            phrase=sample['sentence'],
            spans=sample['spans'],
            aliases=sample['aliases'],
            aliases_seen_by_model=[i for i in range(len(sample['aliases']))],
            seq_len=self.args.data_config.max_word_token_len,
            word_symbols=self.word_db)
        aliases_arr = [[sample['aliases'][idx] for idx in idxs]
                       for idxs in idxs_arr]
        qids_arr = [[sample['qids'][idx] for idx in idxs] for idxs in idxs_arr]
        word_indices_arr = [
            self.word_db.convert_tokens_to_ids(pt) for pt in phrase_tokens_arr
        ]

        if len(idxs_arr) > 1:
            # TODO: support sentences that overflow due to long sequence length or too many mentions
            raise ValueError(
                'Overflowing sentences not currently supported in Annotator')

        # iterate over each sample in the split
        for sub_idx in range(len(idxs_arr)):
            example_aliases = np.ones(self.args.data_config.max_aliases,
                                      dtype=np.int64) * PAD_ID
            example_true_entities = np.ones(
                self.args.data_config.max_aliases) * PAD_ID
            example_aliases_locs_start = np.ones(
                self.args.data_config.max_aliases) * PAD_ID
            example_aliases_locs_end = np.ones(
                self.args.data_config.max_aliases) * PAD_ID

            aliases = aliases_arr[sub_idx]
            for mention_idx, alias in enumerate(aliases):
                # get aliases
                alias_trie_idx = self.entity_db.get_alias_idx(alias)
                alias_qids = np.array(self.entity_db.get_qid_cands(alias))
                example_aliases[mention_idx] = alias_trie_idx

                # alias_idx_pair
                span_idx = spans_arr[sub_idx][mention_idx]
                span_start_idx, span_end_idx = span_idx
                example_aliases_locs_start[mention_idx] = span_start_idx
                example_aliases_locs_end[mention_idx] = span_end_idx

            # get word indices
            word_indices = word_indices_arr[sub_idx]

            # entity indices from alias table (these are the candidates)
            entity_indices = self.alias_table(example_aliases)

            # all CPU embeddings have to be retrieved on the fly
            batch_on_the_fly_data = {}
            for emb_name, emb in self.batch_on_the_fly_embs.items():
                batch_on_the_fly_data[emb_name] = torch.tensor(
                    emb.batch_prep(example_aliases, entity_indices),
                    device=self.device)

            outs, entity_pack, _ = self.model(
                alias_idx_pair_sent=[
                    torch.tensor(example_aliases_locs_start,
                                 device=self.device).unsqueeze(0),
                    torch.tensor(example_aliases_locs_end,
                                 device=self.device).unsqueeze(0)
                ],
                word_indices=torch.tensor([word_indices], device=self.device),
                alias_indices=torch.tensor(example_aliases,
                                           device=self.device).unsqueeze(0),
                entity_indices=torch.tensor(entity_indices,
                                            device=self.device).unsqueeze(0),
                batch_prepped_data={},
                batch_on_the_fly_data=batch_on_the_fly_data)

            entity_cands = eval_utils.map_aliases_to_candidates(
                self.args.data_config.train_in_candidates, self.entity_db,
                aliases)
            # recover predictions
            probs = torch.exp(
                eval_utils.masked_class_logsoftmax(
                    pred=outs[DISAMBIG][FINAL_LOSS],
                    mask=~entity_pack.mask,
                    dim=2))
            max_probs, max_probs_indices = probs.max(2)

            pred_cands = []
            pred_probs = []
            titles = []
            for alias_idx in range(len(aliases)):
                pred_idx = max_probs_indices[0][alias_idx]
                pred_prob = max_probs[0][alias_idx].item()
                pred_qid = entity_cands[alias_idx][pred_idx]
                if pred_prob > self.threshold:
                    pred_cands.append(pred_qid)
                    pred_probs.append(pred_prob)
                    titles.append(
                        self.entity_db.
                        get_title(pred_qid) if pred_qid != 'NC' else 'NC')

            return pred_cands, pred_probs, titles
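
The prediction-recovery block at the end (masked log-softmax over candidate scores, argmax, probability threshold) can be reproduced in isolation. A self-contained sketch with toy scores, standing in for eval_utils.masked_class_logsoftmax:

import torch

# Toy scores: 1 example x 2 aliases x 3 candidates.
logits = torch.tensor([[[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]]])
# True where a candidate slot is real, False where it is padding
# (playing the role of ~entity_pack.mask above).
keep = torch.tensor([[[True, True, False], [True, True, True]]])

probs = torch.softmax(logits.masked_fill(~keep, float("-inf")), dim=2)
max_probs, max_probs_indices = probs.max(2)

threshold = 0.5
for alias_idx in range(probs.shape[1]):
    if max_probs[0][alias_idx].item() > threshold:
        print(alias_idx, max_probs_indices[0][alias_idx].item(),
              round(max_probs[0][alias_idx].item(), 3))
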
Example #5
    def label_mentions(self,
                       text_list,
                       label_func=find_aliases_in_sentence_tag):
        """Extracts mentions and runs disambiguation.

        Args:
            text_list: list of text to disambiguate (or single sentence)
            label_func: mention extraction function (optional)

        Returns: Dict of

            * ``qids``: final predicted QIDs,
            * ``probs``: final predicted probs,
            * ``titles``: final predicted titles,
            * ``cands``: all entity candidates,
            * ``cand_probs``: probabilities of all candidates,
            * ``spans``: final extracted word spans,
            * ``aliases``: final extracted aliases,
        """
        if type(text_list) is str:
            text_list = [text_list]
        else:
            assert (type(text_list) is list and len(text_list) > 0
                    and type(text_list[0]) is str
                    ), f"We only accept inputs of strings and lists of strings"

        ebs = int(self.config.run_config.eval_batch_size)
        self.config.data_config.max_aliases = int(
            self.config.data_config.max_aliases)
        total_start_exs = 0
        total_final_exs = 0
        dropped_by_thresh = 0

        final_char_spans = []

        batch_example_aliases = []
        batch_example_aliases_locs_start = []
        batch_example_aliases_locs_end = []
        batch_example_alias_list_pos = []
        batch_example_true_entities = []
        batch_word_indices = []
        batch_spans_arr = []
        batch_aliases_arr = []
        batch_idx_unq = []
        batch_subsplit_idx = []
        for idx_unq, text in tqdm(
                enumerate(text_list),
                desc="Prepping data",
                total=len(text_list),
                disable=not self.verbose,
        ):
            sample = self.extract_mentions(text, label_func)
            total_start_exs += len(sample["aliases"])
            char_spans = self.get_char_spans(sample["spans"], text)

            final_char_spans.append(char_spans)

            (
                idxs_arr,
                aliases_to_predict_per_split,
                spans_arr,
                phrase_tokens_arr,
                pos_idxs,
            ) = sentence_utils.split_sentence(
                max_aliases=self.config.data_config.max_aliases,
                phrase=sample["sentence"],
                spans=sample["spans"],
                aliases=sample["aliases"],
                aliases_seen_by_model=list(range(len(sample["aliases"]))),
                seq_len=self.config.data_config.max_seq_len,
                is_bert=True,
                tokenizer=self.tokenizer,
            )
            aliases_arr = [[sample["aliases"][idx] for idx in idxs]
                           for idxs in idxs_arr]
            old_spans_arr = [[sample["spans"][idx] for idx in idxs]
                             for idxs in idxs_arr]
            qids_arr = [[sample["qids"][idx] for idx in idxs]
                        for idxs in idxs_arr]
            word_indices_arr = [
                self.tokenizer.convert_tokens_to_ids(pt)
                for pt in phrase_tokens_arr
            ]
            # iterate over each sample in the split

            for sub_idx in range(len(idxs_arr)):
                # ====================================================
                # GENERATE MODEL INPUTS
                # ====================================================
                aliases_to_predict_arr = aliases_to_predict_per_split[sub_idx]

                assert (
                    len(aliases_to_predict_arr) > 0
                ), "There are no aliases to predict for an example. This should not happen at this point."
                assert (
                    len(aliases_arr[sub_idx]) <=
                    self.config.data_config.max_aliases
                ), f"{sample} should have no more than {self.config.data_config.max_aliases} aliases."

                example_aliases = np.ones(
                    self.config.data_config.max_aliases) * PAD_ID
                example_aliases_locs_start = (
                    np.ones(self.config.data_config.max_aliases) * PAD_ID)
                example_aliases_locs_end = (
                    np.ones(self.config.data_config.max_aliases) * PAD_ID)
                example_alias_list_pos = (
                    np.ones(self.config.data_config.max_aliases) * PAD_ID)
                example_true_entities = (
                    np.ones(self.config.data_config.max_aliases) * PAD_ID)

                for mention_idx, alias in enumerate(aliases_arr[sub_idx]):
                    span_start_idx, span_end_idx = spans_arr[sub_idx][
                        mention_idx]
                    # generate indexes into alias table.
                    alias_trie_idx = self.entity_db.get_alias_idx(alias)
                    alias_qids = np.array(self.entity_db.get_qid_cands(alias))
                    if not qids_arr[sub_idx][mention_idx] in alias_qids:
                        # assert not data_args.train_in_candidates
                        if not self.config.data_config.train_in_candidates:
                            # set class label to be "not in candidate set"
                            true_entity_idx = 0
                        else:
                            true_entity_idx = -2
                    else:
                        # Here we are getting the correct class label for training.
                        # Our training is "which of the max_entities entity candidates is the right one
                        # (class labels 1 to max_entities) or is it none of these (class label 0)".
                        # + (not discard_noncandidate_entities) is to ensure label 0 is
                        # reserved for "not in candidate set" class
                        true_entity_idx = np.nonzero(
                            alias_qids == qids_arr[sub_idx][mention_idx]
                        )[0][0] + (
                            not self.config.data_config.train_in_candidates)
                    example_aliases[mention_idx] = alias_trie_idx
                    example_aliases_locs_start[mention_idx] = span_start_idx
                    # The span_idxs are [start, end). We want [start, end]. So subtract 1 from end idx.
                    example_aliases_locs_end[mention_idx] = span_end_idx - 1
                    example_alias_list_pos[mention_idx] = idxs_arr[sub_idx][
                        mention_idx]
                    # leave as -1 if it's not an alias we want to predict; we get these if we split a sentence
                    # and need to only predict subsets
                    if mention_idx in aliases_to_predict_arr:
                        example_true_entities[mention_idx] = true_entity_idx

                # get word indices
                word_indices = word_indices_arr[sub_idx]

                batch_example_aliases.append(example_aliases)
                batch_example_aliases_locs_start.append(
                    example_aliases_locs_start)
                batch_example_aliases_locs_end.append(example_aliases_locs_end)
                batch_example_alias_list_pos.append(example_alias_list_pos)
                batch_example_true_entities.append(example_true_entities)
                batch_word_indices.append(word_indices)
                batch_aliases_arr.append(aliases_arr[sub_idx])
                # Add the original sample spans because spans_arr is w.r.t. BERT subword tokens
                batch_spans_arr.append(old_spans_arr[sub_idx])
                batch_idx_unq.append(idx_unq)
                batch_subsplit_idx.append(sub_idx)

        batch_example_aliases = torch.tensor(batch_example_aliases).long()
        batch_example_aliases_locs_start = torch.tensor(
            batch_example_aliases_locs_start, device=self.torch_device)
        batch_example_aliases_locs_end = torch.tensor(
            batch_example_aliases_locs_end, device=self.torch_device)
        batch_example_true_entities = torch.tensor(batch_example_true_entities,
                                                   device=self.torch_device)
        batch_word_indices = torch.tensor(batch_word_indices,
                                          device=self.torch_device)

        final_pred_cands = [[] for _ in range(len(text_list))]
        final_all_cands = [[] for _ in range(len(text_list))]
        final_cand_probs = [[] for _ in range(len(text_list))]
        final_pred_probs = [[] for _ in range(len(text_list))]
        final_titles = [[] for _ in range(len(text_list))]
        final_spans = [[] for _ in range(len(text_list))]
        final_aliases = [[] for _ in range(len(text_list))]
        for b_i in tqdm(
                range(0, batch_example_aliases.shape[0], ebs),
                desc="Evaluating model",
                disable=not self.verbose,
        ):
            start_span_idx = batch_example_aliases_locs_start[b_i:b_i + ebs]
            end_span_idx = batch_example_aliases_locs_end[b_i:b_i + ebs]
            word_indices = batch_word_indices[b_i:b_i + ebs]
            alias_indices = batch_example_aliases[b_i:b_i + ebs]
            x_dict = self.get_forward_batch(start_span_idx, end_span_idx,
                                            word_indices, alias_indices)
            x_dict["guid"] = torch.arange(b_i,
                                          b_i + ebs,
                                          device=self.torch_device)

            (uid_bdict, _, prob_bdict, _) = self.model(  # type: ignore
                uids=x_dict["guid"],
                X_dict=x_dict,
                Y_dict=None,
                task_to_label_dict=self.task_to_label_dict,
                return_action_outputs=False,
            )
            # ====================================================
            # EVALUATE MODEL OUTPUTS
            # ====================================================
            # recover predictions
            probs = prob_bdict[NED_TASK]
            max_probs = probs.max(2)
            max_probs_indices = probs.argmax(2)
            for ex_i in range(probs.shape[0]):
                idx_unq = batch_idx_unq[b_i + ex_i]
                entity_cands = eval_utils.map_aliases_to_candidates(
                    self.config.data_config.train_in_candidates,
                    self.config.data_config.max_aliases,
                    self.entity_db.get_alias2qids(),
                    batch_aliases_arr[b_i + ex_i],
                )
                # batch size is 1 so we can reshape
                probs_ex = probs[ex_i].reshape(
                    self.config.data_config.max_aliases, probs.shape[2])
                for alias_idx, true_entity_pos_idx in enumerate(
                        batch_example_true_entities[b_i + ex_i]):
                    if true_entity_pos_idx != PAD_ID:
                        pred_idx = max_probs_indices[ex_i][alias_idx]
                        pred_prob = max_probs[ex_i][alias_idx].item()
                        all_cands = entity_cands[alias_idx]
                        pred_qid = all_cands[pred_idx]
                        if pred_prob > self.threshold:
                            final_all_cands[idx_unq].append(all_cands)
                            final_cand_probs[idx_unq].append(
                                probs_ex[alias_idx])
                            final_pred_cands[idx_unq].append(pred_qid)
                            final_pred_probs[idx_unq].append(pred_prob)
                            final_aliases[idx_unq].append(
                                batch_aliases_arr[b_i + ex_i][alias_idx])
                            final_spans[idx_unq].append(
                                batch_spans_arr[b_i + ex_i][alias_idx])
                            final_titles[idx_unq].append(
                                self.entity_db.get_title(pred_qid)
                                if pred_qid != "NC" else "NC")
                            total_final_exs += 1
                        else:
                            dropped_by_thresh += 1
        assert total_final_exs + dropped_by_thresh == total_start_exs, (
            f"Something went wrong and we have predicted fewer mentions than extracted. "
            f"Start {total_start_exs}, Out {total_final_exs}, No cand {dropped_by_thresh}"
        )
        res_dict = {
            "qids": final_pred_cands,
            "probs": final_pred_probs,
            "titles": final_titles,
            "cands": final_all_cands,
            "cand_probs": final_cand_probs,
            "spans": final_spans,
            "aliases": final_aliases,
        }
        return res_dict
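
get_char_spans (called near the top) converts word-level spans into character offsets so callers can locate mentions in the original text. A sketch of that conversion, assuming whitespace tokenization (the real method may handle tokenization differences):

def word_spans_to_char_spans(text, word_spans):
    # Character offset at which each whitespace-separated word begins.
    words, starts, pos = text.split(), [], 0
    for w in words:
        pos = text.index(w, pos)
        starts.append(pos)
        pos += len(w)
    # [start, end) word spans -> [char_start, char_end) offsets.
    return [[starts[s], starts[e - 1] + len(words[e - 1])]
            for s, e in word_spans]

# word_spans_to_char_spans("The big dog ran", [[0, 2]]) -> [[0, 7]]
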
Example #6
    def test_real_cases_bert(self):
        # Example 1
        max_aliases = 10
        max_seq_len = 100

        # Manual data
        sentence = "The guest roster for O'Brien 's final show on January 22\u2014 Tom Hanks , Steve Carell and original first guest Will Ferrell \u2014was regarded by O'Brien as a `` dream lineup '' ; in addition , Neil Young performed his song `` Long May You Run `` and , as the show closed , was joined by Beck , Ferrell ( dressed as Ronnie Van Zant ) , Billy Gibbons , Ben Harper , O'Brien , Viveca Paulin , and The Tonight Show Band to perform the Lynyrd Skynyrd song `` Free Bird `` ."
        aliases = [
            "tom hanks", "steve carell", "will ferrell", "neil young",
            "long may you run", "beck", "ronnie van zant", "billy gibbons",
            "ben harper", "viveca paulin", "lynyrd skynyrd", "free bird"
        ]
        spans = [[11, 13], [14, 16], [20, 22], [36, 38], [42, 46], [57, 58],
                 [63, 66], [68, 70], [71, 73], [76, 78], [87, 89], [91, 93]]
        aliases_to_predict = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]

        # Truth
        true_phrase_arr = [
            [
                '[CLS]', 'The', 'guest', 'roster', 'for', 'O', "'", 'Brien',
                "'", 's', 'final', 'show', 'on', 'January', '22', '—', 'Tom',
                'Hank', '##s', ',', 'Steve', 'Care', '##ll', 'and', 'original',
                'first', 'guest', 'Will', 'Fe', '##rrell', '—', 'was',
                'regarded', 'by', 'O', "'", 'Brien', 'as', 'a', '`', '`',
                'dream', 'lineup', "'", "'", ';', 'in', 'addition', ',',
                'Neil', 'Young', 'performed', 'his', 'song', '`', '`', 'Long',
                'May', 'You', 'Run', '`', '`', 'and', ',', 'as', 'the', 'show',
                'closed', ',', 'was', 'joined', 'by', 'Beck', ',', 'Fe',
                '##rrell', '(', 'dressed', 'as', 'Ronnie', 'Van', 'Z', '##ant',
                ')', ',', 'Billy', 'Gibbons', ',', 'Ben', 'Harper', ',', 'O',
                "'", 'Brien', ',', 'V', '##ive', '##ca', 'Paul', '##in', ',',
                '[SEP]'
            ],
            [
                '[CLS]', 'The', 'guest', 'roster', 'for', 'O', "'", 'Brien',
                "'", 's', 'final', 'show', 'on', 'January', '22',
                '—', 'Tom', 'Hank', '##s', ',', 'Steve', 'Care', '##ll', 'and',
                'original', 'first', 'guest', 'Will', 'Fe', '##rrell', '—',
                'was', 'regarded', 'by', 'O', "'", 'Brien', 'as', 'a', '`',
                '`', 'dream', 'lineup', "'", "'", ';', 'in', 'addition', ',',
                'Neil', 'Young', 'performed', 'his', 'song', '`', '`', 'Long',
                'May', 'You', 'Run', '`', '`', 'and', ',', 'as', 'the', 'show',
                'closed', ',', 'was', 'joined', 'by', 'Beck', ',', 'Fe',
                '##rrell', '(', 'dressed', 'as', 'Ronnie', 'Van', 'Z', '##ant',
                ')', ',', 'Billy', 'Gibbons', ',', 'Ben', 'Harper', ',', 'O',
                "'", 'Brien', ',', 'V', '##ive', '##ca', 'Paul', '##in', ',',
                '[SEP]'
            ],
            [
                '[CLS]', 'original', 'first', 'guest', 'Will', 'Fe', '##rrell',
                '—', 'was', 'regarded', 'by', 'O', "'", 'Brien', 'as', 'a',
                '`', '`', 'dream', 'lineup', "'", "'", ';', 'in', 'addition',
                ',', 'Neil', 'Young', 'performed', 'his', 'song', '`', '`',
                'Long', 'May', 'You', 'Run', '`', '`', 'and', ',', 'as', 'the',
                'show', 'closed', ',', 'was', 'joined', 'by', 'Beck', ',',
                'Fe', '##rrell', '(', 'dressed', 'as', 'Ronnie', 'Van', 'Z',
                '##ant', ')', ',', 'Billy', 'Gibbons', ',', 'Ben', 'Harper',
                ',', 'O', "'", 'Brien', ',', 'V', '##ive', '##ca', 'Paul',
                '##in', ',', 'and', 'The', 'Tonight', 'Show', 'Band', 'to',
                'perform', 'the', 'L', '##yn', '##yr', '##d', 'Sky', '##ny',
                '##rd', 'song', '`', '`', 'Free', 'Bird', '`', '`', '.',
                '[SEP]'
            ]
        ]
        true_spans_arr = [[[12, 14], [15, 17], [23, 25], [40, 42], [46, 50],
                           [61, 62], [67, 70], [72, 74]],
                          [[46, 50], [61, 62], [67, 70], [72, 74], [76, 78],
                           [81, 84], [93, 95], [100, 102]],
                          [[17, 19], [23, 27], [38, 39], [44, 47], [49, 51],
                           [53, 55], [58, 61], [70, 72], [77, 79]]]
        true_alias_to_predict_arr = [[0, 1, 2, 3, 4, 5], [2, 3, 4, 5], [7, 8]]
        true_aliases_arr = [[
            "tom hanks", "steve carell", "will ferrell", "neil young",
            "long may you run", "beck", "ronnie van zant", "billy gibbons"
        ],
                            [
                                "long may you run", "beck", "ronnie van zant",
                                "billy gibbons", "ben harper", "viveca paulin",
                                "lynyrd skynyrd", "free bird"
                            ],
                            [
                                "neil young", "long may you run", "beck",
                                "ronnie van zant", "billy gibbons",
                                "ben harper", "viveca paulin",
                                "lynyrd skynyrd", "free bird"
                            ]]
        # Run function
        args = parser_utils.get_full_config(
            "test/run_args/test_data_bert.json")
        word_symbols = data_utils.load_wordsymbols(args.data_config)
        idxs_arr, aliases_to_predict_arr, spans_arr, phrase_tokens_arr = split_sentence(
            max_aliases, sentence, spans, aliases, aliases_to_predict,
            max_seq_len, word_symbols)
        assert len(idxs_arr) == 3
        assert len(aliases_to_predict_arr) == 3
        assert len(spans_arr) == 3
        assert len(phrase_tokens_arr) == 3

        for i in range(len(idxs_arr)):
            self.assertEqual(len(phrase_tokens_arr[i]), max_seq_len + 2)
            self.assertEqual(phrase_tokens_arr[i], true_phrase_arr[i])
            self.assertEqual(spans_arr[i], true_spans_arr[i])
            self.assertEqual(aliases_to_predict_arr[i],
                             true_alias_to_predict_arr[i])
            self.assertEqual([aliases[idx] for idx in idxs_arr[i]],
                             true_aliases_arr[i])
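
The expected spans above differ from the word-level inputs because BERT's WordPiece tokenization splits words into multiple pieces and a [CLS] token is prepended, so spans must be remapped to subword positions. A standalone sketch of that remapping (pieces_per_word is an assumed precomputed input, not how Bootleg's code is organized):

def remap_spans_to_subwords(word_spans, pieces_per_word, cls_offset=1):
    # Cumulative subword start of each word, then shift by the [CLS] slot.
    starts = [0]
    for n in pieces_per_word:
        starts.append(starts[-1] + n)
    return [[starts[s] + cls_offset, starts[e] + cls_offset]
            for s, e in word_spans]

# A word span [0, 2) over words tokenized into [2, 1] pieces lands at
# subword span [1, 4) once [CLS] is counted:
# remap_spans_to_subwords([[0, 2]], [2, 1]) -> [[1, 4]]
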
Example #7
    def test_split_sentence_bert(self):

        # Example 1
        max_aliases = 30
        max_seq_len = 20

        # Manual data
        sentence = 'Kittens love purpleish pupppeteers because alias2 and spanning the brreaches alias5'
        aliases = ["Kittens love", "alias2", "alias5"]
        spans = [[0, 2], [5, 6], [10, 11]]
        aliases_to_predict = [0, 1, 2]

        # Truth
        bert_tokenized = [
            'Kit', '##tens', 'love', 'purple', '##ish', 'pu', '##pp', '##pet',
            '##eers', 'because', 'alias', '##2', 'and', 'spanning', 'the',
            'br', '##rea', '##ches', 'alias', '##5'
        ]
        true_phrase_arr = [['[CLS]'] + bert_tokenized + ['[SEP]']]
        true_spans_arr = [[[1, 4], [11, 13], [19, 21]]]
        true_alias_to_predict_arr = [[0, 1, 2]]
        true_aliases_arr = [["Kittens love", "alias2", "alias5"]]

        # Run function
        args = parser_utils.get_full_config(
            "test/run_args/test_data_bert.json")
        word_symbols = data_utils.load_wordsymbols(args.data_config)
        idxs_arr, aliases_to_predict_arr, spans_arr, phrase_tokens_arr = split_sentence(
            max_aliases, sentence, spans, aliases, aliases_to_predict,
            max_seq_len, word_symbols)

        assert len(idxs_arr) == 1
        assert len(aliases_to_predict_arr) == 1
        assert len(spans_arr) == 1
        assert len(phrase_tokens_arr) == 1

        for i in range(len(idxs_arr)):
            self.assertEqual(len(phrase_tokens_arr[i]), max_seq_len + 2)
            self.assertEqual(phrase_tokens_arr[i], true_phrase_arr[i])
            self.assertEqual(spans_arr[i], true_spans_arr[i])
            self.assertEqual(aliases_to_predict_arr[i],
                             true_alias_to_predict_arr[i])
            self.assertEqual([aliases[idx] for idx in idxs_arr[i]],
                             true_aliases_arr[i])

        # Example 2
        max_aliases = 30
        max_seq_len = 7

        # Manual data
        sentence = 'Kittens love purpleish pupppeteers because alias2 and spanning the brreaches alias5'
        aliases = ["Kittens love", "alias2", "alias5"]
        spans = [[0, 2], [5, 6], [10, 11]]
        aliases_to_predict = [0, 1, 2]

        # Run function
        args = parser_utils.get_full_config(
            "test/run_args/test_data_bert.json")
        word_symbols = data_utils.load_wordsymbols(args.data_config)
        idxs_arr, aliases_to_predict_arr, spans_arr, phrase_tokens_arr = split_sentence(
            max_aliases, sentence, spans, aliases, aliases_to_predict,
            max_seq_len, word_symbols)

        # Truth
        true_phrase_arr = [[
            '[CLS]', 'Kit', '##tens', 'love', 'purple', '##ish', 'pu', '##pp',
            '[SEP]'
        ],
                           [
                               '[CLS]', '##eers', 'because', 'alias', '##2',
                               'and', 'spanning', 'the', '[SEP]'
                           ],
                           [
                               '[CLS]', 'spanning', 'the', 'br', '##rea',
                               '##ches', 'alias', '##5', '[SEP]'
                           ]]
        true_spans_arr = [[[1, 4]], [[3, 5]], [[6, 8]]]
        true_alias_to_predict_arr = [[0], [0], [0]]
        true_aliases_arr = [["Kittens love"], ["alias2"], ["alias5"]]

        assert len(idxs_arr) == 3
        assert len(aliases_to_predict_arr) == 3
        assert len(spans_arr) == 3
        assert len(phrase_tokens_arr) == 3
        for i in range(len(idxs_arr)):
            self.assertEqual(len(phrase_tokens_arr[i]), max_seq_len + 2)
            self.assertEqual(phrase_tokens_arr[i], true_phrase_arr[i])
            self.assertEqual(spans_arr[i], true_spans_arr[i])
            self.assertEqual(aliases_to_predict_arr[i],
                             true_alias_to_predict_arr[i])
            self.assertEqual([aliases[idx] for idx in idxs_arr[i]],
                             true_aliases_arr[i])

        # Example 3: Test the greedy nature of the algorithm. It greedily
        # packs the first two aliases together; the last alias is split off
        # even though the second alias is also visible in the second split
        # (see the sketch after this example).
        max_aliases = 30
        max_seq_len = 18

        # Manual data
        sentence = 'Kittens Kittens Kittens Kittens love purpleish pupppeteers because alias2 and spanning the brreaches alias5'
        aliases = ["Kittens love", "alias2", "alias5"]
        spans = [[3, 5], [8, 9], [13, 14]]
        aliases_to_predict = [0, 1, 2]

        # Run function
        args = parser_utils.get_full_config(
            "test/run_args/test_data_bert.json")
        word_symbols = data_utils.load_wordsymbols(args.data_config)
        idxs_arr, aliases_to_predict_arr, spans_arr, phrase_tokens_arr = split_sentence(
            max_aliases, sentence, spans, aliases, aliases_to_predict,
            max_seq_len, word_symbols)

        # True data
        true_phrase_arr = [[
            '[CLS]', '##tens', 'Kit', '##tens', 'Kit', '##tens', 'love',
            'purple', '##ish', 'pu', '##pp', '##pet', '##eers', 'because',
            'alias', '##2', 'and', 'spanning', 'the', '[SEP]'
        ],
                           [
                               '[CLS]', 'love', 'purple', '##ish', 'pu',
                               '##pp', '##pet', '##eers', 'because', 'alias',
                               '##2', 'and', 'spanning', 'the', 'br', '##rea',
                               '##ches', 'alias', '##5', '[SEP]'
                           ]]
        true_spans_arr = [[[4, 7], [14, 16]], [[9, 11], [17, 19]]]
        true_alias_to_predict_arr = [[0, 1], [1]]
        true_aliases_arr = [["Kittens love", "alias2"], ["alias2", "alias5"]]

        assert len(idxs_arr) == 2
        assert len(aliases_to_predict_arr) == 2
        assert len(spans_arr) == 2
        assert len(phrase_tokens_arr) == 2
        for i in range(len(idxs_arr)):
            self.assertEqual(len(phrase_tokens_arr[i]), max_seq_len + 2)
            self.assertEqual(phrase_tokens_arr[i], true_phrase_arr[i])
            self.assertEqual(spans_arr[i], true_spans_arr[i])
            self.assertEqual(aliases_to_predict_arr[i],
                             true_alias_to_predict_arr[i])
            self.assertEqual([aliases[idx] for idx in idxs_arr[i]],
                             true_aliases_arr[i])
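
Example 3's greedy behavior can be isolated: aliases are consumed in order and added to the current split while the group still fits one window; otherwise a new split starts. A toy sketch, where fits_in_window is a hypothetical predicate standing in for the max_seq_len check (the real splitter also re-includes earlier aliases that remain visible in a later window, which is why alias2 reappears in the second split above):

def greedy_alias_packing(n_aliases, fits_in_window):
    splits, current = [], []
    for idx in range(n_aliases):
        if not current or fits_in_window(current + [idx]):
            current.append(idx)
        else:
            splits.append(current)
            current = [idx]
    if current:
        splits.append(current)
    return splits

# Three aliases where only the first two fit one window together:
# greedy_alias_packing(3, lambda group: group != [0, 1, 2]) -> [[0, 1], [2]]
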
Example #8
    def test_real_cases(self):
        # Real examples we messed up

        # EXAMPLE 1
        max_aliases = 30
        max_seq_len = 50

        # 3114|0~*~1~*~2~*~3~*~4~*~5|mexico~*~panama~*~ecuador~*~peru~*~bolivia~*~colombia|3966054~*~22997~*~9334~*~170691~*~3462~*~5222|19:20~*~36:37~*~39:40~*~44:45~*~48:49~*~70:71|The animal is called paca in most of its range but tepezcuintle original Aztec language name in most of Mexico and Central America pisquinte in northern Costa Rica jaleb in the Yucatán peninsula conejo pintado in Panama guanta in Ecuador majás or picuro in Peru jochi pintado in Bolivia and boruga tinajo Fauna y flora de la cuenca media del Río Lebrija en Rionegro Santander Humboldt Institute or guartinaja in Colombia
        sentence = 'The animal is called paca in most of its range but tepezcuintle original Aztec language name in most of Mexico and Central America pisquinte in northern Costa Rica jaleb in the Yucatán peninsula conejo pintado in Panama guanta in Ecuador majás or picuro in Peru jochi pintado in Bolivia and boruga tinajo Fauna y flora de la cuenca media del Río Lebrija en Rionegro Santander Humboldt Institute or guartinaja in Colombia'
        aliases = [
            "mexico", "panama", "ecuador", "peru", "bolivia", "colombia"
        ]
        aliases_to_predict = [0, 1, 2, 3, 4, 5]
        spans = [[19, 20], [36, 37], [39, 40], [44, 45], [48, 49], [70, 71]]

        # Run function
        args = parser_utils.get_full_config("test/run_args/test_data.json")
        word_symbols = data_utils.load_wordsymbols(args.data_config)
        idxs_arr, aliases_to_predict_arr, spans_arr, phrase_tokens_arr = split_sentence(
            max_aliases, sentence, spans, aliases, aliases_to_predict,
            max_seq_len, word_symbols)

        # True data
        true_phrase_arr = [
            'range but tepezcuintle original Aztec language name in most of Mexico and Central America pisquinte in northern Costa Rica jaleb in the Yucatán peninsula conejo pintado in Panama guanta in Ecuador majás or picuro in Peru jochi pintado in Bolivia and boruga tinajo Fauna y flora de la cuenca media'
            .split(),
            'Central America pisquinte in northern Costa Rica jaleb in the Yucatán peninsula conejo pintado in Panama guanta in Ecuador majás or picuro in Peru jochi pintado in Bolivia and boruga tinajo Fauna y flora de la cuenca media del Río Lebrija en Rionegro Santander Humboldt Institute or guartinaja in Colombia'
            .split()
        ]
        true_spans_arr = [[[10, 11], [27, 28], [30, 31], [35, 36], [39, 40]],
                          [[15, 16], [18, 19], [23, 24], [27, 28], [49, 50]]]
        true_alias_to_predict_arr = [[0, 1, 2, 3, 4], [4]]
        true_aliases_arr = [["mexico", "panama", "ecuador", "peru", "bolivia"],
                            [
                                "panama", "ecuador", "peru", "bolivia",
                                "colombia"
                            ]]

        assert len(idxs_arr) == 2
        assert len(aliases_to_predict_arr) == 2
        assert len(spans_arr) == 2
        assert len(phrase_tokens_arr) == 2
        for i in range(len(idxs_arr)):
            self.assertEqual(len(phrase_tokens_arr[i]), max_seq_len)
            self.assertEqual(phrase_tokens_arr[i], true_phrase_arr[i])
            self.assertEqual(spans_arr[i], true_spans_arr[i])
            self.assertEqual(aliases_to_predict_arr[i],
                             true_alias_to_predict_arr[i])
            self.assertEqual([aliases[idx] for idx in idxs_arr[i]],
                             true_aliases_arr[i])

        # EXAMPLE 2
        max_aliases = 10
        max_seq_len = 50

        # 20|0~*~1~*~2~*~3~*~4~*~5~*~6~*~7~*~8~*~9~*~10~*~11~*~12~*~13~*~14~*~15~*~16~*~17~*~18~*~19~*~20|coolock~*~swords~*~darndale~*~santry~*~donnycarney~*~baldoyle~*~sutton~*~donaghmede~*~artane~*~whitehall~*~kilbarrack~*~raheny~*~clontarf~*~fairview~*~malahide~*~howth~*~marino~*~ballybough~*~north strand~*~sheriff street~*~east wall|1037463~*~182210~*~8554720~*~2432965~*~7890942~*~1223621~*~1008011~*~3698049~*~1469895~*~2144656~*~3628425~*~1108214~*~1564212~*~1438118~*~944694~*~1037467~*~5745962~*~2436385~*~5310245~*~12170199~*~2814197|12:13~*~14:15~*~15:16~*~17:18~*~18:19~*~19:20~*~20:21~*~21:22~*~22:23~*~23:24~*~24:25~*~25:26~*~26:27~*~27:28~*~28:29~*~29:30~*~30:31~*~38:39~*~39:41~*~41:43~*~43:45|East edition The original east edition is distributed to areas such as Coolock Kilmore Swords Darndale Priorswood Santry Donnycarney Baldoyle Sutton Donaghmede Artane Whitehall Kilbarrack Raheny Clontarf Fairview Malahide Howth Marino and the north east inner city Summerhill Ballybough North Strand Sheriff Street East Wall
        sentence = "East edition The original east edition is distributed to areas such as Coolock Kilmore Swords Darndale Priorswood Santry Donnycarney Baldoyle Sutton Donaghmede Artane Whitehall Kilbarrack Raheny Clontarf Fairview Malahide Howth Marino and the north east inner city Summerhill Ballybough North Strand Sheriff Street East Wall"
        aliases = [
            "coolock", "swords", "darndale", "santry", "donnycarney",
            "baldoyle", "sutton", "donaghmede", "artane", "whitehall",
            "kilbarrack", "raheny", "clontarf", "fairview", "malahide",
            "howth", "marino", "ballybough", "north strand", "sheriff street",
            "east wall"
        ]
        aliases_to_predict = [
            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
            19, 20
        ]
        spans = [[12, 13], [14, 15], [15, 16], [17, 18], [18, 19], [19, 20],
                 [20, 21], [21, 22], [22, 23], [23, 24], [24, 25], [25, 26],
                 [26, 27], [27, 28], [28, 29], [29, 30], [30, 31], [38, 39],
                 [39, 41], [41, 43], [43, 45]]

        # Run function
        args = parser_utils.get_full_config("test/run_args/test_data.json")
        word_symbols = data_utils.load_wordsymbols(args.data_config)
        idxs_arr, aliases_to_predict_arr, spans_arr, phrase_tokens_arr = split_sentence(
            max_aliases, sentence, spans, aliases, aliases_to_predict,
            max_seq_len, word_symbols)

        # Truth
        true_phrase_arr = [
            "East edition The original east edition is distributed to areas such as Coolock Kilmore Swords Darndale Priorswood Santry Donnycarney Baldoyle Sutton Donaghmede Artane Whitehall Kilbarrack Raheny Clontarf Fairview Malahide Howth Marino and the north east inner city Summerhill Ballybough North Strand Sheriff Street East Wall <pad> <pad> <pad> <pad> <pad>"
            .split(),
            "East edition The original east edition is distributed to areas such as Coolock Kilmore Swords Darndale Priorswood Santry Donnycarney Baldoyle Sutton Donaghmede Artane Whitehall Kilbarrack Raheny Clontarf Fairview Malahide Howth Marino and the north east inner city Summerhill Ballybough North Strand Sheriff Street East Wall <pad> <pad> <pad> <pad> <pad>"
            .split(),
            "East edition The original east edition is distributed to areas such as Coolock Kilmore Swords Darndale Priorswood Santry Donnycarney Baldoyle Sutton Donaghmede Artane Whitehall Kilbarrack Raheny Clontarf Fairview Malahide Howth Marino and the north east inner city Summerhill Ballybough North Strand Sheriff Street East Wall <pad> <pad> <pad> <pad> <pad>"
            .split()
        ]
        true_spans_arr = [[[12, 13], [14, 15], [15, 16], [17, 18], [18, 19],
                           [19, 20], [20, 21], [21, 22]],
                          [[20, 21], [21, 22], [22, 23], [23, 24], [24, 25],
                           [25, 26], [26, 27], [27, 28], [28, 29]],
                          [[27, 28], [28, 29], [29, 30], [30, 31], [38, 39],
                           [39, 41], [41, 43], [43, 45]]]
        true_alias_to_predict_arr = [[0, 1, 2, 3, 4, 5, 6],
                                     [1, 2, 3, 4, 5, 6, 7], [1, 2, 3, 6, 7]]
        true_aliases_arr = [[
            "coolock", "swords", "darndale", "santry", "donnycarney",
            "baldoyle", "sutton", "donaghmede"
        ],
                            [
                                "sutton", "donaghmede", "artane", "whitehall",
                                "kilbarrack", "raheny", "clontarf", "fairview",
                                "malahide"
                            ],
                            [
                                "fairview", "malahide", "howth", "marino",
                                "ballybough", "north strand", "sheriff street",
                                "east wall"
                            ]]

        assert len(idxs_arr) == 3
        assert len(aliases_to_predict_arr) == 3
        assert len(spans_arr) == 3
        assert len(phrase_tokens_arr) == 3
        for i in range(len(idxs_arr)):
            self.assertEqual(len(phrase_tokens_arr[i]), max_seq_len)
            self.assertEqual(phrase_tokens_arr[i], true_phrase_arr[i])
            self.assertEqual(spans_arr[i], true_spans_arr[i])
            self.assertEqual(aliases_to_predict_arr[i],
                             true_alias_to_predict_arr[i])
            self.assertEqual([aliases[idx] for idx in idxs_arr[i]],
                             true_aliases_arr[i])

        # EXAMPLE 3
        max_aliases = 10
        max_seq_len = 100

        # 84|0~*~1|kentucky~*~green|621151~*~478999|8:9~*~9:10|The Assembly also reserved tolls collected on the Kentucky Green and Barren rivers for education and passed a two percent property tax to fund the state s schools
        sentence = "The Assembly also reserved tolls collected on the Kentucky Green and Barren rivers for education and passed a two percent property tax to fund the state s schools"
        aliases = ["kentucky", "green"]
        aliases_to_predict = [0, 1]
        spans = [[8, 9], [9, 10]]

        # Run function
        args = parser_utils.get_full_config("test/run_args/test_data.json")
        word_symbols = data_utils.load_wordsymbols(args.data_config)
        idxs_arr, aliases_to_predict_arr, spans_arr, phrase_tokens_arr = split_sentence(
            max_aliases, sentence, spans, aliases, aliases_to_predict,
            max_seq_len, word_symbols)

        # True data
        true_phrase_arr = [
            "The Assembly also reserved tolls collected on the Kentucky Green and Barren rivers for education and passed a two percent property tax to fund the state s schools <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>"
            .split()
        ]
        true_spans_arr = [[[8, 9], [9, 10]]]
        true_alias_to_predict_arr = [[0, 1]]
        true_aliases_arr = [["kentucky", "green"]]

        assert len(idxs_arr) == 1
        assert len(aliases_to_predict_arr) == 1
        assert len(spans_arr) == 1
        assert len(phrase_tokens_arr) == 1
        for i in range(len(idxs_arr)):
            self.assertEqual(len(phrase_tokens_arr[i]), max_seq_len)
            self.assertEqual(phrase_tokens_arr[i], true_phrase_arr[i])
            self.assertEqual(spans_arr[i], true_spans_arr[i])
            self.assertEqual(aliases_to_predict_arr[i],
                             true_alias_to_predict_arr[i])
            self.assertEqual([aliases[idx] for idx in idxs_arr[i]],
                             true_aliases_arr[i])
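
In EXAMPLE 1 above, the first split's window starts at word 9 ("range"), so every span is shifted down by that offset ([19, 20] becomes [10, 11]). The re-indexing is just a uniform shift:

def shift_spans(spans, window_start):
    # Re-index word spans relative to the first token of the split window.
    return [[s - window_start, e - window_start] for s, e in spans]

# shift_spans([[19, 20], [36, 37]], window_start=9) -> [[10, 11], [27, 28]]
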
Example #9
    def test_seq_length(self):

        # Test maximum sequence length
        max_aliases = 30
        max_seq_len = 12

        # Manual data
        sentence = 'The big alias1 ran away from dogs and multi word alias2 and alias3 because we want our cat and our alias5'
        aliases = ["The big", "alias3", "alias5"]
        aliases_to_predict = [0, 1, 2]
        spans = [[0, 2], [12, 13], [20, 21]]

        # Run function
        args = parser_utils.get_full_config("test/run_args/test_data.json")
        word_symbols = data_utils.load_wordsymbols(args.data_config)
        idxs_arr, aliases_to_predict_arr, spans_arr, phrase_tokens_arr = split_sentence(
            max_aliases, sentence, spans, aliases, aliases_to_predict,
            max_seq_len, word_symbols)

        # True data
        true_phrase_arr = [
            "The big alias1 ran away from dogs and multi word alias2 and".
            split(),
            "word alias2 and alias3 because we want our cat and our alias5".
            split()
        ]
        true_spans_arr = [[[0, 2]], [[3, 4], [11, 12]]]
        true_alias_to_predict_arr = [[0], [0, 1]]
        true_aliases_arr = [["The big"], ["alias3", "alias5"]]

        assert len(idxs_arr) == 2
        assert len(aliases_to_predict_arr) == 2
        assert len(spans_arr) == 2
        assert len(phrase_tokens_arr) == 2
        for i in range(len(idxs_arr)):
            self.assertEqual(len(phrase_tokens_arr[i]), max_seq_len)
            self.assertEqual(phrase_tokens_arr[i], true_phrase_arr[i])
            self.assertEqual(spans_arr[i], true_spans_arr[i])
            self.assertEqual(aliases_to_predict_arr[i],
                             true_alias_to_predict_arr[i])
            self.assertEqual([aliases[idx] for idx in idxs_arr[i]],
                             true_aliases_arr[i])

        # Now test with modified aliases to predict
        aliases_to_predict = [1, 2]

        # Run function
        args = parser_utils.get_full_config("test/run_args/test_data.json")
        word_symbols = data_utils.load_wordsymbols(args.data_config)
        idxs_arr, aliases_to_predict_arr, spans_arr, phrase_tokens_arr = split_sentence(
            max_aliases, sentence, spans, aliases, aliases_to_predict,
            max_seq_len, word_symbols)

        # True data
        true_phrase_arr = [
            "word alias2 and alias3 because we want our cat and our alias5".
            split()
        ]
        true_spans_arr = [[[3, 4], [11, 12]]]
        true_alias_to_predict_arr = [[0, 1]]
        true_aliases_arr = [["alias3", "alias5"]]

        assert len(idxs_arr) == 1
        assert len(aliases_to_predict_arr) == 1
        assert len(spans_arr) == 1
        assert len(phrase_tokens_arr) == 1
        for i in range(len(idxs_arr)):
            self.assertEqual(len(phrase_tokens_arr[i]), max_seq_len)
            self.assertEqual(phrase_tokens_arr[i], true_phrase_arr[i])
            self.assertEqual(spans_arr[i], true_spans_arr[i])
            self.assertEqual(aliases_to_predict_arr[i],
                             true_alias_to_predict_arr[i])
            self.assertEqual([aliases[idx] for idx in idxs_arr[i]],
                             true_aliases_arr[i])
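
Restricting aliases_to_predict from [0, 1, 2] to [1, 2] removes the first window entirely, because a split is only emitted if it contains at least one alias to predict. A sketch of that filter over already-computed splits (names are illustrative):

def drop_splits_without_predictions(splits, aliases_to_predict):
    # Keep only splits containing at least one alias index to predict.
    wanted = set(aliases_to_predict)
    return [s for s in splits if wanted.intersection(s)]

# drop_splits_without_predictions([[0], [1, 2]], [1, 2]) -> [[1, 2]]
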
Example #10
    def label_mentions(self, text_list):
        if type(text_list) is str:
            text_list = [text_list]
        else:
            assert type(text_list) is list and len(text_list) > 0 and type(
                text_list[0]) is str, f"We only accept inputs of strings and lists of strings"

        ebs = self.args.run_config.eval_batch_size
        total_start_exs = 0
        total_final_exs = 0
        dropped_by_thresh = 0

        final_char_spans = []

        batch_example_aliases = []
        batch_example_aliases_locs_start = []
        batch_example_aliases_locs_end = []
        batch_example_alias_list_pos = []
        batch_example_true_entities = []
        batch_word_indices = []
        batch_spans_arr = []
        batch_aliases_arr = []
        batch_idx_unq = []
        batch_subsplit_idx = []
        for idx_unq, text in tqdm(enumerate(text_list), desc="Prepping data", total=len(text_list)):
            sample = self.extract_mentions(text)
            total_start_exs += len(sample['aliases'])
            char_spans = self.get_char_spans(sample['spans'], text)

            final_char_spans.append(char_spans)

            idxs_arr, aliases_to_predict_per_split, spans_arr, phrase_tokens_arr, pos_idxs = sentence_utils.split_sentence(
                max_aliases=self.args.data_config.max_aliases,
                phrase=sample['sentence'],
                spans=sample['spans'],
                aliases=sample['aliases'],
                aliases_seen_by_model=[i for i in range(len(sample['aliases']))],
                seq_len=self.args.data_config.max_word_token_len,
                word_symbols=self.word_db)
            aliases_arr = [[sample['aliases'][idx] for idx in idxs] for idxs in idxs_arr]
            old_spans_arr = [[sample['spans'][idx] for idx in idxs] for idxs in idxs_arr]
            qids_arr = [[sample['qids'][idx] for idx in idxs] for idxs in idxs_arr]
            word_indices_arr = [self.word_db.convert_tokens_to_ids(pt) for pt in phrase_tokens_arr]
            # iterate over each sample in the split

            for sub_idx in range(len(idxs_arr)):
                # ====================================================
                # GENERATE MODEL INPUTS
                # ====================================================
                aliases_to_predict_arr = aliases_to_predict_per_split[sub_idx]

                assert len(aliases_to_predict_arr) > 0, 'There are no aliases to predict for an example. This should not happen at this point.'
                assert len(aliases_arr[sub_idx]) <= self.args.data_config.max_aliases, \
                    f'Each example should have no more than {self.args.data_config.max_aliases} aliases. {sample} does.'

                example_aliases = np.ones(self.args.data_config.max_aliases) * PAD_ID
                example_aliases_locs_start = np.ones(self.args.data_config.max_aliases) * PAD_ID
                example_aliases_locs_end = np.ones(self.args.data_config.max_aliases) * PAD_ID
                example_alias_list_pos = np.ones(self.args.data_config.max_aliases) * PAD_ID
                example_true_entities = np.ones(self.args.data_config.max_aliases) * PAD_ID

                for mention_idx, alias in enumerate(aliases_arr[sub_idx]):
                    span_start_idx, span_end_idx = spans_arr[sub_idx][mention_idx]
                    # generate indexes into alias table.
                    alias_trie_idx = self.entity_db.get_alias_idx(alias)
                    alias_qids = np.array(self.entity_db.get_qid_cands(alias))
                    if not qids_arr[sub_idx][mention_idx] in alias_qids:
                        # assert not data_args.train_in_candidates
                        if not self.args.data_config.train_in_candidates:
                            # set class label to be "not in candidate set"
                            true_entity_idx = 0
                        else:
                            true_entity_idx = -2
                    else:
                        # Here we are getting the correct class label for training.
                        # Our training is "which of the max_entities entity candidates is the right one (class labels 1 to max_entities) or is it none of these (class label 0)".
                        # + (not discard_noncandidate_entities) is to ensure label 0 is reserved for "not in candidate set" class
                        true_entity_idx = np.nonzero(alias_qids == qids_arr[sub_idx][mention_idx])[0][0] + (
                            not self.args.data_config.train_in_candidates)
                    example_aliases[mention_idx] = alias_trie_idx
                    example_aliases_locs_start[mention_idx] = span_start_idx
                    # The span_idxs are [start, end). We want [start, end]. So subtract 1 from end idx.
                    example_aliases_locs_end[mention_idx] = span_end_idx - 1
                    example_alias_list_pos[mention_idx] = idxs_arr[sub_idx][mention_idx]
                    # leave as -1 if it's not an alias we want to predict; we get these if we split a sentence and need to only predict subsets
                    if mention_idx in aliases_to_predict_arr:
                        example_true_entities[mention_idx] = true_entity_idx

                # get word indices
                word_indices = word_indices_arr[sub_idx]

                batch_example_aliases.append(example_aliases)
                batch_example_aliases_locs_start.append(example_aliases_locs_start)
                batch_example_aliases_locs_end.append(example_aliases_locs_end)
                batch_example_alias_list_pos.append(example_alias_list_pos)
                batch_example_true_entities.append(example_true_entities)
                batch_word_indices.append(word_indices)
                batch_aliases_arr.append(aliases_arr[sub_idx])
                # Add the original sample spans because spans_arr is w.r.t. BERT subword tokens
                batch_spans_arr.append(old_spans_arr[sub_idx])
                batch_idx_unq.append(idx_unq)
                batch_subsplit_idx.append(sub_idx)

        batch_example_aliases = torch.tensor(batch_example_aliases).long()
        batch_example_aliases_locs_start = torch.tensor(batch_example_aliases_locs_start, device=self.device)
        batch_example_aliases_locs_end = torch.tensor(batch_example_aliases_locs_end, device=self.device)
        batch_example_true_entities = torch.tensor(batch_example_true_entities, device=self.device)
        batch_word_indices = torch.tensor(batch_word_indices, device=self.device)

        final_pred_cands = [[] for _ in range(len(text_list))]
        final_all_cands = [[] for _ in range(len(text_list))]
        final_cand_probs = [[] for _ in range(len(text_list))]
        final_pred_probs = [[] for _ in range(len(text_list))]
        final_titles = [[] for _ in range(len(text_list))]
        final_spans = [[] for _ in range(len(text_list))]
        final_aliases = [[] for _ in range(len(text_list))]
        for b_i in tqdm(range(0, batch_example_aliases.shape[0], ebs), desc="Evaluating model"):
            # entity indices from alias table (these are the candidates)
            batch_entity_indices = self.alias_table(batch_example_aliases[b_i:b_i + ebs])

            # all CPU embeddings have to be retrieved on the fly
            batch_on_the_fly_data = {}
            for emb_name, emb in self.batch_on_the_fly_embs.items():
                batch_prep = []
                for j in range(b_i, min(b_i + ebs, batch_example_aliases.shape[0])):
                    batch_prep.append(emb.batch_prep(batch_example_aliases[j], batch_entity_indices[j - b_i]))
                batch_on_the_fly_data[emb_name] = torch.tensor(batch_prep, device=self.device)

            alias_idx_pair_sent = [batch_example_aliases_locs_start[b_i:b_i + ebs], batch_example_aliases_locs_end[b_i:b_i + ebs]]
            word_indices = batch_word_indices[b_i:b_i + ebs]
            alias_indices = batch_example_aliases[b_i:b_i + ebs]
            entity_indices = torch.tensor(batch_entity_indices, device=self.device)

            outs, entity_pack, _ = self.model(
                alias_idx_pair_sent=alias_idx_pair_sent,
                word_indices=word_indices,
                alias_indices=alias_indices,
                entity_indices=entity_indices,
                batch_prepped_data={},
                batch_on_the_fly_data=batch_on_the_fly_data)

            # ====================================================
            # EVALUATE MODEL OUTPUTS
            # ====================================================

            final_loss_vals = outs[DISAMBIG][FINAL_LOSS]
            # recover predictions
            probs = torch.exp(eval_utils.masked_class_logsoftmax(pred=final_loss_vals,
                                                                 mask=~entity_pack.mask, dim=2))
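            # masked_class_logsoftmax excludes the candidate slots ruled out by the mask,
            # so each alias's probabilities normalize over its real candidates only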
            max_probs, max_probs_indices = probs.max(2)
            for ex_i in range(final_loss_vals.shape[0]):
                idx_unq = batch_idx_unq[b_i + ex_i]
                subsplit_idx = batch_subsplit_idx[b_i + ex_i]
                entity_cands = eval_utils.map_aliases_to_candidates(self.args.data_config.train_in_candidates,
                                                                    self.entity_db,
                                                                    batch_aliases_arr[b_i + ex_i])

                # batch size is 1 so we can reshape
                probs_ex = probs[ex_i].detach().cpu().numpy().reshape(self.args.data_config.max_aliases, probs.shape[2])
                for alias_idx, true_entity_pos_idx in enumerate(batch_example_true_entities[b_i + ex_i]):
                    if true_entity_pos_idx != PAD_ID:
                        pred_idx = max_probs_indices[ex_i][alias_idx]
                        pred_prob = max_probs[ex_i][alias_idx].item()
                        all_cands = entity_cands[alias_idx]
                        pred_qid = all_cands[pred_idx]
                        if pred_prob > self.threshold:
                            final_all_cands[idx_unq].append(all_cands)
                            final_cand_probs[idx_unq].append(probs_ex[alias_idx])
                            final_pred_cands[idx_unq].append(pred_qid)
                            final_pred_probs[idx_unq].append(pred_prob)
                            final_aliases[idx_unq].append(batch_aliases_arr[b_i + ex_i][alias_idx])
                            final_spans[idx_unq].append(batch_spans_arr[b_i + ex_i][alias_idx])
                            final_titles[idx_unq].append(self.entity_db.get_title(pred_qid) if pred_qid != 'NC' else 'NC')
                            total_final_exs += 1
                        else:
                            dropped_by_thresh += 1
        assert total_final_exs + dropped_by_thresh == total_start_exs, f"Something went wrong and we have predicted fewer mentions than extracted. Start {total_start_exs}, Out {total_final_exs}, No cand {dropped_by_thresh}"
        return final_pred_cands, final_pred_probs, final_titles, final_all_cands, final_cand_probs, final_spans, final_aliases
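
As a rough usage sketch of the seven-tuple returned above (the method head is truncated here, so the annotator instance `ann` and the method name `label_mentions`, mirroring Example #11, are assumptions, not part of the original example):

# Hypothetical sketch: `ann` is an assumed annotator instance exposing the
# method above; the seven returned lists are parallel per input sentence.
(qids, probs, titles, cands, cand_probs, spans, aliases) = ann.label_mentions(
    ["Abraham Lincoln was the 16th president of the United States"])
for qid, prob, title, span in zip(qids[0], probs[0], titles[0], spans[0]):
    # one row per mention that survived the probability threshold
    print(f"{title} ({qid}) p={prob:.3f} span={span}")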
Example #11
    def label_mentions(
        self,
        text_list=None,
        label_func=find_aliases_in_sentence_tag,
        extracted_examples=None,
    ):
        """Extracts mentions and runs disambiguation. If user provides extracted_examples, we will ignore text_list

        Args:
            text_list: list of text to disambiguate (or single string) (can be None if extracted_examples is not None)
            label_func: mention extraction function (optional)
            extracted_examples: List of Dicts of keys "sentence", "aliases", "spans", "cands" (QIDs) (optional)

        Returns: Dict of

            * ``qids``: final predicted QIDs,
            * ``probs``: final predicted probs,
            * ``titles``: final predicted titles,
            * ``cands``: all entity candidates,
            * ``cand_probs``: probabilities of all candidates,
            * ``spans``: final extracted word spans,
            * ``aliases``: final extracted aliases,
            * ``embs``: final entity contextualized embeddings (if return_embs is True)
            * ``cand_embs``: final candidate entity contextualized embeddings (if return_embs is True)
        """
        # Check inputs are sane
        do_extract_mentions = True
        if extracted_examples is not None:
            do_extract_mentions = False
            assert (type(extracted_examples) is list
                    ), "Must provide a list of Dicts for extracted_examples"
            check_ex = extracted_examples[0]
            assert (len({
                "sentence", "aliases", "spans", "cands"
            }.intersection(check_ex.keys())) == 4), (
                f"You must have keys of sentence, aliases, spans, and cands for extracted_examples. You have "
                f"{check_ex.keys()}")
        else:
            assert (
                text_list is not None
            ), "If you do not provide extracted_examples you must provide text_list"

        if text_list is None:
            assert extracted_examples is not None, (
                "If you do not provide text_list "
                "you must provide extracted_examples")
        else:
            if type(text_list) is str:
                text_list = [text_list]
            else:
                assert (
                    type(text_list) is list and len(text_list) > 0
                    and type(text_list[0]) is str
                ), f"We only accept inputs of strings and lists of strings"

        # Get number of examples
        if extracted_examples is not None:
            num_exs = len(extracted_examples)
        else:
            num_exs = len(text_list)

        ebs = int(self.config.run_config.eval_batch_size)
        self.config.data_config.max_aliases = int(
            self.config.data_config.max_aliases)
        total_start_exs = 0
        total_final_exs = 0
        dropped_by_thresh = 0

        final_char_spans = []

        batch_example_qid_cands = []
        batch_example_eid_cands = []
        batch_example_aliases_locs_start = []
        batch_example_aliases_locs_end = []
        batch_example_alias_list_pos = []
        batch_example_true_entities = []
        batch_word_indices = []
        batch_spans_arr = []
        batch_example_aliases = []
        batch_idx_unq = []
        batch_subsplit_idx = []
        for idx_unq in tqdm(
                range(num_exs),
                desc="Prepping data",
                total=num_exs,
                disable=not self.verbose,
        ):
            if do_extract_mentions:
                sample = self.extract_mentions(text_list[idx_unq], label_func)
            else:
                sample = extracted_examples[idx_unq]
                # Add the unk qids and gold values
                sample["qids"] = ["Q-1" for _ in range(len(sample["aliases"]))]
                sample["gold"] = [True for _ in range(len(sample["aliases"]))]
            total_start_exs += len(sample["aliases"])
            char_spans = self.get_char_spans(sample["spans"],
                                             sample["sentence"])

            final_char_spans.append(char_spans)

            (
                idxs_arr,
                aliases_to_predict_per_split,
                spans_arr,
                phrase_tokens_arr,
                pos_idxs,
            ) = sentence_utils.split_sentence(
                max_aliases=self.config.data_config.max_aliases,
                phrase=sample["sentence"],
                spans=sample["spans"],
                aliases=sample["aliases"],
                aliases_seen_by_model=list(range(len(sample["aliases"]))),
                seq_len=self.config.data_config.max_seq_len,
                is_bert=True,
                tokenizer=self.tokenizer,
            )
            aliases_arr = [[sample["aliases"][idx] for idx in idxs]
                           for idxs in idxs_arr]
            old_spans_arr = [[sample["spans"][idx] for idx in idxs]
                             for idxs in idxs_arr]
            qids_arr = [[sample["qids"][idx] for idx in idxs]
                        for idxs in idxs_arr]
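            # idxs_arr maps each split back into the original mention list, so the
            # aliases, spans, and gold qids above are re-indexed per split window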
            word_indices_arr = [
                self.tokenizer.convert_tokens_to_ids(pt)
                for pt in phrase_tokens_arr
            ]
            # iterate over each sample in the split
            for sub_idx in range(len(idxs_arr)):
                # ====================================================
                # GENERATE MODEL INPUTS
                # ====================================================
                aliases_to_predict_arr = aliases_to_predict_per_split[sub_idx]

                assert (
                    len(aliases_to_predict_arr) > 0
                ), "There are no aliases to predict for an example. This should not happen at this point."
                assert (
                    len(aliases_arr[sub_idx]) <=
                    self.config.data_config.max_aliases
                ), f"{sample} should have no more than {self.config.data_config.max_aliases} aliases."

                example_aliases_locs_start = (
                    np.ones(self.config.data_config.max_aliases) * PAD_ID)
                example_aliases_locs_end = (
                    np.ones(self.config.data_config.max_aliases) * PAD_ID)
                example_alias_list_pos = (
                    np.ones(self.config.data_config.max_aliases) * PAD_ID)
                example_true_entities = (
                    np.ones(self.config.data_config.max_aliases) * PAD_ID)
                max_cands = get_max_candidates(self.entity_db,
                                               self.config.data_config)
                example_qid_cands = [["-1"] * max_cands for _ in range(
                    self.config.data_config.max_aliases)]
                example_eid_cands = [[-1] * max_cands for _ in range(
                    self.config.data_config.max_aliases)]
                for mention_idx, alias in enumerate(aliases_arr[sub_idx]):
                    span_start_idx, span_end_idx = spans_arr[sub_idx][
                        mention_idx]
                    # generate indexes into alias table.
                    alias_qids = np.array(sample["cands"][mention_idx])
                    # The first entry is the non-candidate class (QID "NC", eid 0), used when
                    # train_in_candidates is False; if we train in candidates, that slot is
                    # simply overwritten because the offset below is 0
                    example_qid_cands[mention_idx][0] = "NC"
                    example_eid_cands[mention_idx][0] = 0
                    offset = int(
                        not self.config.data_config.train_in_candidates)
                    example_qid_cands[mention_idx][
                        offset:len(alias_qids) + offset] = sample["cands"][mention_idx]
                    example_eid_cands[mention_idx][
                        offset:len(alias_qids) + offset] = [
                            self.entity_db.get_eid(q)
                            for q in sample["cands"][mention_idx]
                        ]
                    if qids_arr[sub_idx][mention_idx] not in alias_qids:
                        # assert not data_args.train_in_candidates
                        if not self.config.data_config.train_in_candidates:
                            # set class label to be "not in candidate set"
                            true_entity_idx = 0
                        else:
                            true_entity_idx = -2
                    else:
                        # Here we are getting the correct class label for training.
                        # Our training is "which of the max_entities entity candidates is the right one
                        # (class labels 1 to max_entities) or is it none of these (class label 0)".
                        # The + (not train_in_candidates) offset ensures label 0 is
                        # reserved for the "not in candidate set" class
                        true_entity_idx = np.nonzero(
                            alias_qids == qids_arr[sub_idx][mention_idx]
                        )[0][0] + (
                            not self.config.data_config.train_in_candidates)
                    example_aliases_locs_start[mention_idx] = span_start_idx
                    # The span_idxs are [start, end). We want [start, end]. So subtract 1 from end idx.
                    example_aliases_locs_end[mention_idx] = span_end_idx - 1
                    example_alias_list_pos[mention_idx] = idxs_arr[sub_idx][
                        mention_idx]
                    # leave as -1 if it's not an alias we want to predict; we get these if we split a sentence
                    # and need to only predict subsets
                    if mention_idx in aliases_to_predict_arr:
                        example_true_entities[mention_idx] = true_entity_idx

                # get word indices
                word_indices = word_indices_arr[sub_idx]

                batch_example_qid_cands.append(example_qid_cands)
                batch_example_eid_cands.append(example_eid_cands)
                batch_example_aliases_locs_start.append(
                    example_aliases_locs_start)
                batch_example_aliases_locs_end.append(example_aliases_locs_end)
                batch_example_alias_list_pos.append(example_alias_list_pos)
                batch_example_true_entities.append(example_true_entities)
                batch_word_indices.append(word_indices)
                batch_example_aliases.append(aliases_arr[sub_idx])
                # Add the original sample spans because spans_arr is w.r.t. BERT subword tokens
                batch_spans_arr.append(old_spans_arr[sub_idx])
                batch_idx_unq.append(idx_unq)
                batch_subsplit_idx.append(sub_idx)

        batch_example_eid_cands = torch.tensor(batch_example_eid_cands).long()
        batch_example_aliases_locs_start = torch.tensor(
            batch_example_aliases_locs_start)
        batch_example_aliases_locs_end = torch.tensor(
            batch_example_aliases_locs_end)
        batch_example_true_entities = torch.tensor(batch_example_true_entities)
        batch_word_indices = torch.tensor(batch_word_indices)

        final_pred_cands = [[] for _ in range(num_exs)]
        final_all_cands = [[] for _ in range(num_exs)]
        final_cand_probs = [[] for _ in range(num_exs)]
        final_pred_probs = [[] for _ in range(num_exs)]
        final_entity_embs = [[] for _ in range(num_exs)]
        final_entity_cand_embs = [[] for _ in range(num_exs)]
        final_titles = [[] for _ in range(num_exs)]
        final_spans = [[] for _ in range(num_exs)]
        final_aliases = [[] for _ in range(num_exs)]
        for b_i in tqdm(
                range(0, batch_word_indices.shape[0], ebs),
                desc="Evaluating model",
                disable=not self.verbose,
        ):
            start_span_idx = batch_example_aliases_locs_start[b_i:b_i + ebs]
            end_span_idx = batch_example_aliases_locs_end[b_i:b_i + ebs]
            word_indices = batch_word_indices[b_i:b_i + ebs]
            eid_cands = batch_example_eid_cands[b_i:b_i + ebs]
            x_dict = self.get_forward_batch(start_span_idx, end_span_idx,
                                            word_indices, eid_cands)
            x_dict["guid"] = torch.arange(b_i,
                                          b_i + ebs,
                                          device=self.torch_device)

            with torch.no_grad():
                res = self.model(  # type: ignore
                    uids=x_dict["guid"],
                    X_dict=x_dict,
                    Y_dict=None,
                    task_to_label_dict=self.task_to_label_dict,
                    return_action_outputs=self.return_embs,
                )
            del x_dict
            if self.return_embs:
                (uid_bdict, _, prob_bdict, _, out_bdict) = res
                output_embs = out_bdict[NED_TASK][f"{PRED_LAYER}_ent_embs"]
            else:
                output_embs = None
                (uid_bdict, _, prob_bdict, _) = res
            # ====================================================
            # EVALUATE MODEL OUTPUTS
            # ====================================================
            # recover predictions
            probs = prob_bdict[NED_TASK]
            max_probs = probs.max(2)
            max_probs_indices = probs.argmax(2)
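            # probs is a [batch, max_aliases, num_candidates] array, so the max and
            # argmax over dim 2 give each alias slot's best candidate and its score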
            for ex_i in range(probs.shape[0]):
                idx_unq = batch_idx_unq[b_i + ex_i]
                entity_cands = batch_example_qid_cands[b_i + ex_i]
                # batch size is 1 so we can reshape
                probs_ex = probs[ex_i].reshape(
                    self.config.data_config.max_aliases, probs.shape[2])
                for alias_idx, true_entity_pos_idx in enumerate(
                        batch_example_true_entities[b_i + ex_i]):
                    if true_entity_pos_idx != PAD_ID:
                        pred_idx = max_probs_indices[ex_i][alias_idx]
                        pred_prob = max_probs[ex_i][alias_idx].item()
                        all_cands = entity_cands[alias_idx]
                        pred_qid = all_cands[pred_idx]
                        if pred_prob > self.threshold:
                            final_all_cands[idx_unq].append(all_cands)
                            final_cand_probs[idx_unq].append(
                                probs_ex[alias_idx])
                            final_pred_cands[idx_unq].append(pred_qid)
                            final_pred_probs[idx_unq].append(pred_prob)
                            if self.return_embs:
                                final_entity_embs[idx_unq].append(
                                    output_embs[ex_i][alias_idx][pred_idx])
                                final_entity_cand_embs[idx_unq].append(
                                    output_embs[ex_i][alias_idx])
                            final_aliases[idx_unq].append(
                                batch_example_aliases[b_i + ex_i][alias_idx])
                            final_spans[idx_unq].append(
                                batch_spans_arr[b_i + ex_i][alias_idx])
                            final_titles[idx_unq].append(
                                self.entity_db.get_title(pred_qid)
                                if pred_qid != "NC" else "NC")
                            total_final_exs += 1
                        else:
                            dropped_by_thresh += 1
        assert total_final_exs + dropped_by_thresh == total_start_exs, (
            f"Something went wrong and we have predicted fewer mentions than extracted. "
            f"Start {total_start_exs}, Out {total_final_exs}, No cand {dropped_by_thresh}"
        )
        res_dict = {
            "qids": final_pred_cands,
            "probs": final_pred_probs,
            "titles": final_titles,
            "cands": final_all_cands,
            "cand_probs": final_cand_probs,
            "spans": final_spans,
            "aliases": final_aliases,
        }
        if self.return_embs:
            res_dict["embs"] = final_entity_embs
            res_dict["cand_embs"] = final_entity_cand_embs
        return res_dict
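
To close, a minimal sketch of calling label_mentions with pre-extracted mentions, using the docstring's required keys; the annotator instance `ann` and the candidate QIDs are assumptions for illustration:

# Hypothetical sketch: `ann` is an assumed annotator instance; the candidate
# QIDs below are illustrative placeholders.
example = {
    "sentence": "Lincoln was the 16th president",
    "aliases": ["lincoln"],
    "spans": [[0, 1]],             # word-level [start, end) span of the mention
    "cands": [["Q91", "Q28108"]],  # candidate QIDs for each alias
}
res = ann.label_mentions(extracted_examples=[example])
# Parallel per-sentence lists, keyed as in the docstring
print(res["qids"][0], res["probs"][0], res["titles"][0])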