Example #1
    def __init__(
        self,
        wiki_link_db_path: str,
        model_redirect_mappings_path: str,
        link_redirect_mappings_path: str,
        entity_vocab_path: str,
        source_language: str = "en",
        inter_wiki_path: str = None,
        multilingual_entity_db_path: Dict[str, str] = None,
        min_mention_link_prob: float = 0.01,
        max_mention_length: int = 10,
    ):
        self.tokenizer = None
        self.wiki_link_db = WikiLinkDB(wiki_link_db_path)
        self.model_redirect_mappings: Dict[str, str] = joblib.load(model_redirect_mappings_path)
        self.link_redirect_mappings: Dict[str, str] = joblib.load(link_redirect_mappings_path)

        self.source_language = source_language
        if inter_wiki_path is not None:
            self.inter_wiki_db = InterwikiDB.load(inter_wiki_path)
        else:
            self.inter_wiki_db = None

        self.entity_vocab = EntityVocab(entity_vocab_path)

        multilingual_entity_db_path = multilingual_entity_db_path or {}
        self.entity_db_dict = {lang: EntityDB(path) for lang, path in multilingual_entity_db_path.items()}

        self.min_mention_link_prob = min_mention_link_prob

        self.max_mention_length = max_mention_length
Example #2
def build_wikipedia_pretraining_dataset(dump_db_file: str, tokenizer_name: str,
                                        entity_vocab_file: str,
                                        output_dir: str,
                                        sentence_tokenizer: str, **kwargs):
    dump_db = DumpDB(dump_db_file)
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    sentence_tokenizer = SentenceTokenizer.from_name(sentence_tokenizer)

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    entity_vocab = EntityVocab(entity_vocab_file)
    WikipediaPretrainingDataset.build(dump_db, tokenizer, sentence_tokenizer,
                                      entity_vocab, output_dir, **kwargs)
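A minimal usage sketch for the builder above. All file paths and the sentence tokenizer name are assumptions for illustration; the keyword arguments are forwarded via **kwargs to WikipediaPretrainingDataset.build (see Example #6), which defines no defaults, so every builder parameter is passed explicitly.

# Hypothetical invocation; the paths and the "icu" tokenizer name are placeholders, not values from the source.
build_wikipedia_pretraining_dataset(
    dump_db_file="data/enwiki_dump.db",
    tokenizer_name="roberta-base",
    entity_vocab_file="data/entity_vocab.jsonl",
    output_dir="data/wikipedia_pretraining",
    sentence_tokenizer="icu",
    # The remaining arguments are consumed by WikipediaPretrainingDataset.build via **kwargs.
    max_seq_length=512,
    max_entity_length=128,
    max_mention_length=30,
    min_sentence_length=5,
    include_sentences_without_entities=False,
    include_unk_entities=False,
    pool_size=8,
    chunk_size=100,
    max_num_documents=None,
)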
Example #3
def test_save_and_load(multilingual_entity_vocab):
    with tempfile.NamedTemporaryFile() as f:
        multilingual_entity_vocab.save(f.name)
        entity_vocab2 = EntityVocab(f.name)

        assert len(multilingual_entity_vocab) == len(entity_vocab2)

        # check if the two vocabs are identical after save and load
        for ent_id in range(len(multilingual_entity_vocab)):
            entities1 = multilingual_entity_vocab.inv_vocab[ent_id]
            entities2 = entity_vocab2.inv_vocab[ent_id]
            assert set(entities1) == set(entities2)
            assert multilingual_entity_vocab.counter[entities1[0]] == entity_vocab2.counter[entities2[0]]
            assert multilingual_entity_vocab.vocab[entities1[0]] == entity_vocab2.vocab[entities2[0]]
Example #4
def build_medmentions_pretraining_dataset(medmentions_db_file: str,
                                          tokenizer_name: str,
                                          entity_vocab_file: str,
                                          output_dir: str,
                                          sentence_tokenizer: str, **kwargs):
    medmentions_db = MedMentionsDB(medmentions_db_file)
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    sentence_tokenizer = SentenceTokenizer.from_name(sentence_tokenizer)
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    entity_vocab = EntityVocab(entity_vocab_file)
    MedMentionsPretrainingDataset.build(medmentions_db, tokenizer,
                                        sentence_tokenizer, entity_vocab,
                                        output_dir, **kwargs)
Example #5
    def entity_vocab(self):
        vocab_file_path = get_entity_vocab_file_path(self._dataset_dir)
        return EntityVocab(vocab_file_path)
Example #6
    def build(
        cls,
        dump_db: DumpDB,
        tokenizer: PreTrainedTokenizer,
        sentence_tokenizer: SentenceTokenizer,
        entity_vocab: EntityVocab,
        output_dir: str,
        max_seq_length: int,
        max_entity_length: int,
        max_mention_length: int,
        min_sentence_length: int,
        include_sentences_without_entities: bool,
        include_unk_entities: bool,
        pool_size: int,
        chunk_size: int,
        max_num_documents: int,
    ):

        target_titles = [
            title for title in dump_db.titles()
            if not (":" in title and title.lower().split(":")[0] in
                    ("image", "file", "category"))
        ]
        random.shuffle(target_titles)

        if max_num_documents is not None:
            target_titles = target_titles[:max_num_documents]

        max_num_tokens = max_seq_length - 2  # 2 for [CLS] and [SEP]

        tokenizer.save_pretrained(output_dir)

        entity_vocab.save(os.path.join(output_dir, ENTITY_VOCAB_FILE))
        number_of_items = 0
        tf_file = os.path.join(output_dir, DATASET_FILE)
        options = tf.io.TFRecordOptions(
            tf.compat.v1.io.TFRecordCompressionType.GZIP)
        with TFRecordWriter(tf_file, options=options) as writer:
            with tqdm(total=len(target_titles)) as pbar:
                initargs = (
                    dump_db,
                    tokenizer,
                    sentence_tokenizer,
                    entity_vocab,
                    max_num_tokens,
                    max_entity_length,
                    max_mention_length,
                    min_sentence_length,
                    include_sentences_without_entities,
                    include_unk_entities,
                )
                with closing(
                        Pool(pool_size,
                             initializer=WikipediaPretrainingDataset.
                             _initialize_worker,
                             initargs=initargs)) as pool:
                    for ret in pool.imap(
                            WikipediaPretrainingDataset._process_page,
                            target_titles,
                            chunksize=chunk_size):
                        for data in ret:
                            writer.write(data)
                            number_of_items += 1
                        pbar.update()

        with open(os.path.join(output_dir, METADATA_FILE),
                  "w") as metadata_file:
            json.dump(
                dict(
                    number_of_items=number_of_items,
                    max_seq_length=max_seq_length,
                    max_entity_length=max_entity_length,
                    max_mention_length=max_mention_length,
                    min_sentence_length=min_sentence_length,
                    tokenizer_class=tokenizer.__class__.__name__,
                    language=dump_db.language,
                ),
                metadata_file,
                indent=2,
            )
Example #7
def multilingual_entity_vocab():
    return EntityVocab(MULTILINGUAL_ENTITY_VOCAB_FIXTURE_FILE)
Example #8
def entity_vocab():
    return EntityVocab(ENTITY_VOCAB_FIXTURE_FILE)
Example #9
    def __init__(
        self,
        pretrained_weight_path: str,
        pretrained_metadata_path: str,
        entity_vocab_path: str = None,
        train_parameters: bool = True,
        gradient_checkpointing: bool = False,
        num_special_mask_embeddings: int = None,
        output_entity_embeddings: bool = False,
        num_additional_special_tokens: int = None,
        discard_entity_embeddings: bool = False,
    ) -> None:
        """

        Parameters
        ----------
        pretrained_weight_path: `str`
            Path to the LUKE pre-trained weights.

        pretrained_metadata_path: `str`
            Path to the LUKE pre-trained metadata file, typically stored in the same directory as pretrained_weight_path.

        entity_vocab_path: `str`
            Path to the LUKE entity vocabulary.

        train_parameters: `bool`
            Whether to fine-tune the pre-trained weights or keep them frozen.

        gradient_checkpointing: `bool`
            Enable gradient checkpointing, which significantly reduces memory usage.

        num_special_mask_embeddings: `int`
            If specified, the model discards all the entity embeddings
            and only uses this number of embeddings, each initialized with the [MASK] embedding.
            This is used for tasks such as named entity recognition (num_special_mask_embeddings=1)
            or relation classification (num_special_mask_embeddings=2).

        output_entity_embeddings: `bool`
            If True, the model returns entity embeddings instead of token embeddings.
            If you need both, please use PretrainedLukeEmbedderWithEntity.

        num_additional_special_tokens: `int`
            Used when adding special tokens to the pre-trained vocabulary.
        discard_entity_embeddings: `bool`
            Replace entity embeddings with a dummy vector to save memory.
        """
        super().__init__()

        self.metadata = json.load(open(pretrained_metadata_path,
                                       "r"))["model_config"]
        if entity_vocab_path is not None:
            self.entity_vocab = EntityVocab(entity_vocab_path)
        else:
            self.entity_vocab = None

        model_weights = torch.load(pretrained_weight_path,
                                   map_location=torch.device("cpu"))
        self.num_special_mask_embeddings = num_special_mask_embeddings
        if num_special_mask_embeddings:
            assert self.entity_vocab is not None
            pad_id = self.entity_vocab.special_token_ids[PAD_TOKEN]
            mask_id = self.entity_vocab.special_token_ids[MASK_TOKEN]
            self.metadata[
                "entity_vocab_size"] = 1 + num_special_mask_embeddings
            entity_emb = model_weights[
                "entity_embeddings.entity_embeddings.weight"]
            mask_emb = entity_emb[mask_id].unsqueeze(0)
            pad_emb = entity_emb[pad_id].unsqueeze(0)
            model_weights[
                "entity_embeddings.entity_embeddings.weight"] = torch.cat(
                    [pad_emb] +
                    [mask_emb for _ in range(num_special_mask_embeddings)])

        if discard_entity_embeddings:
            self.metadata["entity_vocab_size"] = 1
            model_weights[
                "entity_embeddings.entity_embeddings.weight"] = torch.zeros(
                    1, self.metadata["entity_emb_size"])

        config = LukeConfig(
            entity_vocab_size=self.metadata["entity_vocab_size"],
            bert_model_name=self.metadata["bert_model_name"],
            entity_emb_size=self.metadata["entity_emb_size"],
            **AutoConfig.from_pretrained(
                self.metadata["bert_model_name"]).to_dict(),
        )
        config.gradient_checkpointing = gradient_checkpointing

        self.output_entity_embeddings = output_entity_embeddings

        self.luke_model = LukeModel(config)
        self.luke_model.load_state_dict(model_weights, strict=False)

        if num_additional_special_tokens:
            word_emb = self.luke_model.embeddings.word_embeddings.weight
            embed_size = word_emb.size(1)
            additional_embs = [
                torch.rand(1, embed_size)
                for _ in range(num_additional_special_tokens)
            ]
            augmented_weight = torch.nn.Parameter(
                torch.cat([word_emb] + additional_embs, dim=0))
            self.luke_model.embeddings.word_embeddings.weight = augmented_weight

        if not train_parameters:
            # Freeze all parameters of the loaded LUKE model.
            for param in self.luke_model.parameters():
                param.requires_grad = False
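A minimal instantiation sketch based on the docstring above, configured for relation classification (num_special_mask_embeddings=2). The enclosing class name is not shown in this excerpt, so PretrainedLukeEmbedder and all file paths below are assumptions.

# Hypothetical class name and checkpoint paths; adjust to the actual LUKE checkpoint layout.
embedder = PretrainedLukeEmbedder(
    pretrained_weight_path="luke_large/pytorch_model.bin",
    pretrained_metadata_path="luke_large/metadata.json",
    entity_vocab_path="luke_large/entity_vocab.tsv",  # required because num_special_mask_embeddings is set
    train_parameters=True,
    num_special_mask_embeddings=2,  # two [MASK]-initialized entity embeddings, e.g. head and tail entities
)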
Example #10
class WikiMentionDetector(FromParams):
    """
    Detect entity mentions in text from Wikipedia articles.
    """

    def __init__(
        self,
        wiki_link_db_path: str,
        model_redirect_mappings_path: str,
        link_redirect_mappings_path: str,
        entity_vocab_path: str,
        source_language: str = "en",
        inter_wiki_path: str = None,
        multilingual_entity_db_path: Dict[str, str] = None,
        min_mention_link_prob: float = 0.01,
        max_mention_length: int = 10,
    ):
        self.tokenizer = None
        self.wiki_link_db = WikiLinkDB(wiki_link_db_path)
        self.model_redirect_mappings: Dict[str, str] = joblib.load(model_redirect_mappings_path)
        self.link_redirect_mappings: Dict[str, str] = joblib.load(link_redirect_mappings_path)

        self.source_language = source_language
        if inter_wiki_path is not None:
            self.inter_wiki_db = InterwikiDB.load(inter_wiki_path)
        else:
            self.inter_wiki_db = None

        self.entity_vocab = EntityVocab(entity_vocab_path)

        multilingual_entity_db_path = multilingual_entity_db_path or {}
        self.entity_db_dict = {lang: EntityDB(path) for lang, path in multilingual_entity_db_path.items()}

        self.min_mention_link_prob = min_mention_link_prob

        self.max_mention_length = max_mention_length

    def set_tokenizer(self, tokenizer: PreTrainedTokenizer):
        self.tokenizer = tokenizer

    def get_mention_candidates(self, title: str) -> Dict[str, str]:
        """
        Returns a dict of [mention, entity (title)]
        """
        title = self.link_redirect_mappings.get(title, title)

        # mention_to_entity
        mention_candidates: Dict[str, str] = {}
        ambiguous_mentions: Set[str] = set()

        for link in self.wiki_link_db.get(title):
            if link.link_prob < self.min_mention_link_prob:
                continue

            link_text = self._normalize_mention(link.text)
            if link_text in mention_candidates and mention_candidates[link_text] != link.title:
                ambiguous_mentions.add(link_text)
                continue

            mention_candidates[link_text] = link.title

        for link_text in ambiguous_mentions:
            del mention_candidates[link_text]
        return mention_candidates

    def _detect_mentions(self, tokens: List[str], mention_candidates: Dict[str, str], language: str) -> List[Mention]:
        mentions = []
        cur = 0
        for start, token in enumerate(tokens):
            if start < cur:
                continue

            for end in range(min(start + self.max_mention_length, len(tokens)), start, -1):

                mention_text = self.tokenizer.convert_tokens_to_string(tokens[start:end])
                mention_text = self._normalize_mention(mention_text)
                if mention_text in mention_candidates:
                    cur = end
                    title = mention_candidates[mention_text]
                    title = self.model_redirect_mappings.get(title, title)  # resolve mismatch between two dumps
                    if self.entity_vocab.contains(title, language):
                        mention = Mention(Entity(title, language), start, end)
                        mentions.append(mention)
                    break

        return mentions

    def detect_mentions(self, tokens: List[Token], title: str, language: str) -> List[Mention]:

        if self.tokenizer is None:
            raise RuntimeError("self.tokenizer is None. Did you call self.set_tokenizer()?")

        source_mention_candidates = self.get_mention_candidates(title)

        if language == self.source_language:
            target_mention_candidates = source_mention_candidates
        else:
            if self.inter_wiki_db is None:
                raise ValueError(
                    f"You need InterWikiDB to detect mentions from other languages except for {self.source_language}."
                )
            source_entities = list(source_mention_candidates.values())

            target_entities = []
            for ent in source_entities:
                translated_ent = self.inter_wiki_db.get_title_translation(ent, self.source_language, language)
                if translated_ent is not None:
                    target_entities.append(translated_ent)

            target_mention_candidates = {}
            for target_entity in target_entities:
                for entity, mention, count in self.entity_db_dict[language].query(target_entity):
                    target_mention_candidates[mention] = entity

        target_mentions = self._detect_mentions([t.text for t in tokens], target_mention_candidates, language)

        return target_mentions

    @staticmethod
    def _normalize_mention(text: str):
        return " ".join(text.lower().split(" ")).strip()

    def mentions_to_entity_features(self, tokens: List[Token], mentions: List[Mention]) -> Dict:

        if len(mentions) == 0:
            entity_ids = [0]
            entity_type_ids = [0]
            entity_attention_mask = [0]
            entity_position_ids = [[-1 for y in range(self.max_mention_length)]]
        else:
            entity_ids = [0] * len(mentions)
            entity_type_ids = [0] * len(mentions)
            entity_attention_mask = [1] * len(mentions)
            entity_position_ids = [[-1 for y in range(self.max_mention_length)] for x in range(len(mentions))]

            for i, (entity, start, end) in enumerate(mentions):
                entity_ids[i] = self.entity_vocab.get_id(entity.title, entity.language)
                entity_position_ids[i][: end - start] = range(start, end)

                if tokens[start].type_id is not None:
                    entity_type_ids[i] = tokens[start].type_id

        return {
            "entity_ids": entity_ids,
            "entity_attention_mask": entity_attention_mask,
            "entity_position_ids": entity_position_ids,
            "entity_type_ids": entity_type_ids,
        }