Example #1
    def test_wiki_candidate_generator_no_candidates(self):
        fake_entity_world = {
            "Germany": "11867",
            "United_Kingdom": "31717",
            "European_Commission": "42336"
        }

        candidate_generator = WikiCandidateMentionGenerator(
            'tests/fixtures/linking/priors.txt',
            entity_world_path=fake_entity_world)

        candidates = candidate_generator.get_mentions_raw_text(".")
        assert candidates['candidate_entities'] == [['@@PADDING@@']]
Example #2
    def __init__(self,
                 token_indexers: Dict[str, TokenIndexer] = None,
                 entity_indexer: TokenIndexer = TokenIndexer.from_params(
                     Params(INDEXER_DEFAULT)),
                 granularity: str = "sentence",
                 mention_generator: MentionGenerator = None,
                 should_remap_span_indices: bool = True,
                 entity_disambiguation_only: bool = False,
                 extra_candidate_generators: Dict[str,
                                                  MentionGenerator] = None):

        lazy = False
        super().__init__(lazy)
        self.token_indexers = token_indexers or {
            "token": SingleIdTokenIndexer("token")
        }
        self.entity_indexer = {"ids": entity_indexer}
        self.separator = {"*NL*"}
        if granularity == "sentence":
            self.separator.add(".")

        if granularity not in {"sentence", "paragraph"}:
            raise ConfigurationError(
                "Valid arguments for granularity are 'sentence' or 'paragraph'."
            )

        self.entity_disambiguation_only = entity_disambiguation_only
        self.mention_generator = mention_generator or WikiCandidateMentionGenerator(
        )
        self.should_remap_span_indices = should_remap_span_indices

        self.extra_candidate_generators = extra_candidate_generators
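For context, a minimal construction sketch (not part of the original examples; the fixture paths mirror the tests further down): with granularity="paragraph" only the "*NL*" marker separates units, whereas the default "sentence" granularity also splits on ".".

    generator = WikiCandidateMentionGenerator('tests/fixtures/linking/priors.txt')
    reader = LinkingReader(mention_generator=generator, granularity="paragraph")
    instances = list(reader.read('tests/fixtures/linking/aida.txt'))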
Example #3
    def test_read(self):

        candidate_generator = WikiCandidateMentionGenerator(
            "tests/fixtures/linking/priors.txt")
        assert len(candidate_generator.p_e_m) == 50

        assert set(candidate_generator.p_e_m.keys()) == {
            'United States', 'Information', 'Wiki', 'France', 'English',
            'Germany', 'World War II', '2007', 'England', 'American', 'Canada',
            'Australia', 'Japan', '2008', 'India', '2006', 'Area Info',
            'London', 'German', 'About Company', 'French', 'United Kingdom',
            'Italy', 'en', 'California', 'China', '2005', 'New York', 'Spain',
            'Europe', 'British', '2004', 'New York City', 'Russia',
            'public domain', '2000', 'Brazil', 'Poland', 'micro-blogging',
            'Greek', 'New Zealand', '2003', 'Mexico', 'Italian', 'Ireland',
            'Wiki Image', 'Paris', 'USA', '[1]', 'Iran'
        }

        lower = candidate_generator.process("united states")
        string_list = candidate_generator.process(["united", "states"])
        upper = candidate_generator.process(["United", "States"])
        assert lower == upper == string_list
Example #4
    def test_wiki_linking_reader_with_wordnet(self):
        def _get_indexer(namespace):
            return TokenIndexer.from_params(
                Params({
                    "type": "characters_tokenizer",
                    "tokenizer": {
                        "type": "word",
                        "word_splitter": {
                            "type": "just_spaces"
                        },
                    },
                    "namespace": namespace
                }))

        extra_generator = {
            'wordnet':
            WordNetCandidateMentionGenerator(
                'tests/fixtures/wordnet/entities_fixture.jsonl')
        }

        fake_entity_world = {
            "Germany": "11867",
            "United_Kingdom": "31717",
            "European_Commission": "42336"
        }
        candidate_generator = WikiCandidateMentionGenerator(
            'tests/fixtures/linking/priors.txt',
            entity_world_path=fake_entity_world)
        train_file = 'tests/fixtures/linking/aida.txt'

        reader = LinkingReader(mention_generator=candidate_generator,
                               entity_indexer=_get_indexer("entity_wiki"),
                               extra_candidate_generators=extra_generator)
        instances = reader.read(train_file)

        assert len(instances) == 2
Example #5
    def test_wiki_candidate_generator_simple(self):
        candidate_generator = WikiCandidateMentionGenerator(
            'tests/fixtures/linking/priors.txt')
        s = "Mexico is bordered to the north by the United States."

        # first candidate in each list
        candidates = candidate_generator.get_mentions_raw_text(s)
        first_prior = [
            span_candidates[0]
            for span_candidates in candidates['candidate_entities']
        ]
        assert first_prior == ['Mexico', 'United_States']

        # now do it randomly
        candidate_generator.random_candidates = True
        candidate_generator.p_e_m_keys_for_sampling = list(
            candidate_generator.p_e_m.keys())
        candidates = candidate_generator.get_mentions_raw_text(s)
        first_prior = [
            span_candidates[0]
            for span_candidates in candidates['candidate_entities']
        ]
        assert first_prior != ['Mexico', 'United_States']
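For orientation, get_mentions_raw_text returns a dictionary of parallel lists ("candidate_spans", "candidate_entities", "candidate_entity_priors"), one entry per detected span; the test above only inspects the first candidate of each span. The sketch below is a minimal illustration, not part of the original examples, and the import path is an assumption.

    from kb.wiki_linking_util import WikiCandidateMentionGenerator  # assumed module path

    generator = WikiCandidateMentionGenerator('tests/fixtures/linking/priors.txt')
    mentions = generator.get_mentions_raw_text(
        "Mexico is bordered to the north by the United States.")

    # The three lists are aligned: one entry per candidate span.
    for span, entities, priors in zip(mentions["candidate_spans"],
                                      mentions["candidate_entities"],
                                      mentions["candidate_entity_priors"]):
        print(span, entities[:3], priors[:3])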
Example #6
class EntityLinkingAsLM:
    NO_DECAY = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    SPECIAL_ENT_RGX = re.compile("@@.+?@@")

    def __init__(self,
                 bert_name="bert-base-cased",
                 ebert_name="wikipedia2vec-base-cased",
                 mapper_name="linear",
                 device=0,
                 ent_prefix="ENTITY/",
                 left_pattern=["[MASK]", "/"],
                 right_pattern=["*"],
                 granularity="document",
                 max_len=512,
                 margin=1,
                 dynamic_max_len=True,
                 max_candidates=1000,
                 do_use_priors=False,
                 do_prime_mask=False,
                 seed=0):
        """
        bert_name : name of bert model (compatible with pytorch_transformers)
        ebert_name : name of ebert embeddings (file $ebert_name$ should exist)
        device : CUDA device
        ent_prefix : prefix used to distinguish entities from words
        left_pattern : pattern introduced before candidate mention
        right_pattern : pattern introduced after candidate mention
        (one of left_pattern, right_pattern should contain [MASK])
        granularity : splitting AIDA at document or paragraph level
        max_len : maximum length of input to BERT
        margin : margin for max margin training (not used)
        dynamic_max_len : whether to use dynamic padding
        max_candidates : number of candidates per potential mention
        do_use_priors : whether to use candidate generator priors (not recommended)
        do_prime_mask : whether to prime [MASK] token with average candidate (recommended)
        seed : random seed
        """

        self.dynamic_max_len = dynamic_max_len
        self.device = device
        self.max_len = max_len
        self.left_pattern = left_pattern
        self.right_pattern = right_pattern
        self.granularity = granularity
        self.ent_prefix = ent_prefix

        self.rnd = np.random.RandomState(seed)
        self.do_use_priors = do_use_priors
        self.do_prime_mask = do_prime_mask

        assert self.granularity in ("paragraph", "document")
        assert self.max_len <= 512
        assert len([
            x for x in self.left_pattern + self.right_pattern if x == "[MASK]"
        ]) == 1

        self.tokenizer = BertTokenizer.from_pretrained(bert_name,
                                                       do_lower_case="uncased"
                                                       in bert_name)

        assert not any(
            [w.startswith(ent_prefix) for w in self.tokenizer.vocab.keys()])

        self.ebert_emb = load_embedding(ebert_name, prefix=self.ent_prefix)
        if mapper_name and (mapper_name != "None"):
            self.ebert_emb = MappedEmbedding(
                self.ebert_emb,
                load_mapper(f"{ebert_name}.{bert_name}.{mapper_name}.npy"))

        tmp_bert_model = BertModel.from_pretrained(bert_name)
        self.bert_emb = tmp_bert_model.embeddings.word_embeddings
        del tmp_bert_model

        self.model = EmbInputBertForMaskedEmbLM.from_pretrained(bert_name).to(
            device=self.device)

        null_vector = self.rnd.uniform(
            low=-self.model.config.initializer_range,
            high=self.model.config.initializer_range,
            size=(self.model.config.hidden_size, ))
        self.null_vector = Variable(torch.tensor(null_vector).to(
            dtype=torch.float, device=self.device),
                                    requires_grad=True)
        self.candidate_generator = WikiCandidateMentionGenerator(
            entity_world_path=None, max_candidates=max_candidates)

        if self.do_use_priors:
            self.null_bias = Variable(torch.zeros(
                (1, )).to(dtype=torch.float, device=self.device),
                                      requires_grad=True)

    def score_f1(self, true, pred):
        assert len(pred) == len(true)
        guessed, gold, correct = 0, 0, 0

        for t, p in zip(true, pred):
            if t == 0 and p == 0: pass
            elif t == 0 and p != 0: guessed += 1
            elif t != 0 and p == 0: gold += 1
            else:
                gold += 1
                guessed += 1
                if t == p: correct += 1

        prec_micro = 1.0 if guessed == 0 else correct / guessed
        rec_micro = 0.0 if gold == 0 else correct / gold
        f1_micro = 0 if prec_micro + rec_micro == 0 else 2 * prec_micro * rec_micro / (
            prec_micro + rec_micro)

        return prec_micro, rec_micro, f1_micro

    def data2samples(self, data, verbose=True):
        samples = []
        for data_idx, (sentence,
                       spans) in tqdm(data.items(),
                                      disable=not verbose,
                                      desc="Converting data to samples"):
            mentions = self.candidate_generator.get_mentions_raw_text(
                " ".join(sentence), whitespace_tokenize=True)
            span2candidates = {}

            for (start, end), entities, priors in zip(
                    mentions["candidate_spans"],
                    mentions["candidate_entities"],
                    mentions["candidate_entity_priors"]):
                if any([
                        x.startswith(self.ent_prefix) or x == "[PAD]"
                        for x in sentence[start:end + 1]
                ]):
                    continue

                normalized = [
                    normalize_entity(self.ebert_emb, entity, self.ent_prefix)
                    for entity in entities
                ]

                valid_entities = []
                valid_entities_set = set()
                valid_priors = []

                for entity, prior in zip(normalized, priors):
                    if entity is None or entity in valid_entities_set:
                        continue

                    valid_entities.append(entity)
                    valid_priors.append(prior)
                    valid_entities_set.add(entity)

                biases = [None] + [np.log(x) for x in valid_priors]
                entities = [None] + valid_entities

                if len(entities) > 1:
                    span2candidates[(start, end)] = (entities, biases)

            span2gold = {(start, end):
                         normalize_entity(self.ebert_emb, entity,
                                          self.ent_prefix)
                         for entity, start, end in spans}

            for (start, end), (candidates, biases) in span2candidates.items():
                assert not ("[CLS]" in sentence or "[MASK]" in sentence
                            or "[SEP]" in sentence or "[UNK]" in sentence)

                gold_candidate = span2gold.get((start, end), None)
                if gold_candidate not in candidates:
                    continue

                correct_idx = candidates.index(gold_candidate)

                for entity in candidates:
                    if entity not in self.ent2idx:
                        self.ent2idx[entity] = len(self.ent2idx)

                sentence_with_pattern = sentence[:start] + self.left_pattern + sentence[
                    start:end + 1] + self.right_pattern + sentence[end + 1:]

                present_entities = [
                    token for token in sentence_with_pattern
                    if token.startswith(self.ent_prefix)
                ]
                sentence_with_pattern_ent = [
                    "[UNK]" if token.startswith(self.ent_prefix) else token
                    for token in sentence_with_pattern
                ]
                sample_tokenized = self.tokenizer.tokenize(
                    " ".join(sentence_with_pattern_ent))

                ent_pos = [
                    i for i, token in enumerate(sample_tokenized)
                    if token == "[UNK]"
                ]
                del_pos = [
                    i for i, token in enumerate(sample_tokenized)
                    if token == "[PAD]"
                ]

                assert len(ent_pos) == len(present_entities)
                for pos, ent in zip(ent_pos, present_entities):
                    sample_tokenized[pos] = ent

                for pos in sorted(del_pos, reverse=True):
                    del sample_tokenized[pos]

                if not self.do_use_priors:
                    biases = None

                mask_pos = sample_tokenized.index("[MASK]")

                sample_tokenized = ["[CLS]"] + sample_tokenized + ["[SEP]"]
                input_ids = self.tokenizer.convert_tokens_to_ids(
                    sample_tokenized)

                samples.append(
                    Sample(input_ids=input_ids,
                           tokenized=sample_tokenized,
                           mask_pos=sample_tokenized.index("[MASK]"),
                           correct_idx=correct_idx,
                           candidate_ids=[
                               self.ent2idx[candidate]
                               for candidate in candidates
                           ],
                           biases=biases,
                           sentence=sentence,
                           candidates=candidates,
                           start=start,
                           end=end,
                           data_idx=data_idx))

        if verbose:
            for i in self.rnd.randint(low=0, high=len(samples), size=(5, )):
                print(samples[i], flush=True)

        return samples

    def read_aida_file(self, f, ignore_gold=True):

        data, sentence, spans = {}, [], []
        flag = False
        doc_count = -1

        with open(f) as handle:
            for line in handle:
                line = line.strip()
                if len(line) == 0: continue

                if line.startswith("DOCSTART"):
                    doc_count += 1
                    in_doc_count = -1
                if (line == "*NL*" and self.granularity
                        == "paragraph") or line == "DOCEND":
                    if len(sentence):
                        in_doc_count += 1
                        if ignore_gold:
                            data[(doc_count, in_doc_count)] = [sentence, []]
                        else:
                            data[(doc_count, in_doc_count)] = [sentence, spans]
                    sentence, spans = [], []
                elif line == "*NL*":
                    pass
                elif line.startswith("DOCSTART"):
                    doc_count += 1
                elif line.startswith("MMEND"):
                    flag = False
                elif flag:
                    spans[-1][-1] += 1
                    sentence.append(line.split()[0])
                elif line.startswith("MMSTART"):
                    assert not flag
                    gold_entity = line.split()[-1]
                    spans.append(
                        [gold_entity,
                         len(sentence),
                         len(sentence) - 1])
                    flag = True
                else:
                    sentence.append(line.split()[0])

        if 1:  # split segments that would exceed BERT's max_len at a sentence boundary near the midpoint
            delete, checked = set(), set()

            while len(delete) + len(checked) != len(data):
                for data_idx, (sentence, spans) in data.items():
                    if data_idx not in delete and data_idx not in checked:
                        if len(
                                self.tokenizer.tokenize(" ".join(
                                    sentence + self.left_pattern +
                                    self.right_pattern))) >= self.max_len - 2:
                            midpoint = len(sentence) // 2
                            breaking = False
                            for i in range(0, midpoint - 5):
                                if breaking: break
                                for direction in (-1, 1):
                                    point = midpoint + (direction * i)
                                    if sentence[point] == "." and not any([
                                            x[1] <= point + 1
                                            and x[2] >= point + 1
                                            for x in spans
                                    ]):
                                        midpoint = point + 1
                                        breaking = True
                                        break

                            sentence_a = sentence[:midpoint]
                            sentence_b = sentence[midpoint:]

                            spans_a = [x for x in spans if x[2] < midpoint]
                            spans_b = [(x[0], x[1] - midpoint, x[2] - midpoint)
                                       for x in spans if x[2] > midpoint]

                            data[data_idx + (1, )] = [sentence_a, spans_a]
                            data[data_idx + (2, )] = [sentence_b, spans_b]
                            delete.add(data_idx)
                            break

                        else:
                            checked.add(data_idx)

            assert len(delete.intersection(checked)) == 0
            for data_idx in delete:
                del data[data_idx]

        return data

    def normalize_predictions(self, data, predictions, sort_by_null=True):
        data_idx_offsets = {}
        current_offset = 0
        text = []
        for data_idx in sorted(list(data.keys())):
            data_idx_offsets[data_idx] = current_offset
            current_offset += len(data[data_idx][0])
            text.extend(data[data_idx][0])

        span2prediction = {}
        for entity, data_idx, start, end, ep in predictions:
            if entity is None: continue

            start += data_idx_offsets[data_idx]
            end += data_idx_offsets[data_idx]
            assert (start, end) not in span2prediction
            assert ep[0][0] is None
            null_p = ep[0][1]
            ep.sort(key=lambda x: x[1], reverse=True)
            assert ep[0][0] == entity
            span2prediction[(start, end)] = (entity, text[start:end + 1],
                                             ep[0][1], null_p)

        if sort_by_null:
            spans_sorted = sorted(list(span2prediction.keys()),
                                  key=lambda x: span2prediction[x][3])
        else:
            spans_sorted = sorted(list(span2prediction.keys()),
                                  key=lambda x: span2prediction[x][2],
                                  reverse=True)

        blocked = set()
        for start, end in spans_sorted:
            for x in range(start, end + 1):
                if x in blocked:
                    del span2prediction[(start, end)]
                    break

            if (start, end) in span2prediction:
                blocked.update(range(start, end + 1))

        return span2prediction

    def _predict_sentence(self,
                          sentence,
                          batch_size=4):  #, gold_spans = None):
        self.ent2idx = {None: 0}

        samples = self.data2samples({(0, ): [sentence, []]}, verbose=False)
        predictions = self.pred_loop(samples,
                                     batch_size=batch_size,
                                     verbose=False)
        span2prediction = {(start, end): ents_and_probas
                           for _, _, start, end, ents_and_probas in predictions
                           }
        assert all(
            [span2prediction[key][0][0] is None for key in span2prediction])

        spans_sorted = sorted(list(span2prediction.keys()),
                              key=lambda x: span2prediction[x][0][1])

        blocked = set()
        for start, end in spans_sorted:
            for x in range(start, end + 1):
                if x in blocked:
                    del span2prediction[(start, end)]
                    break

            if (start, end) in span2prediction:
                blocked.update(range(start, end + 1))

        spans = []
        for start, end in span2prediction:
            probas = np.array([x[1] for x in span2prediction[(start, end)]])
            ent = span2prediction[(start, end)][probas.argmax()][0]
            if ent is not None:
                spans.append([ent, start, end, probas.argmax()])

        return spans

    def predict_sentence(self,
                         sentence,
                         batch_size=4,
                         iterations=1):  #, gold_spans = None):
        sentence = copy.deepcopy(sentence)

        spans = []
        for it in range(iterations):
            pred_spans = self._predict_sentence(
                sentence, batch_size=batch_size)  #, gold_spans = gold_spans)

            if len(pred_spans) == 0:
                break

            pred_spans.sort(key=lambda x: x[-1], reverse=True)

            if it + 1 < iterations:
                x = (it + 1) * (len(spans) +
                                len(pred_spans)) // iterations - len(spans)
                pred_spans = pred_spans[:max(x, 1)]

            spans.extend([x[:-1] for x in pred_spans])
            for entity, start, end, _ in pred_spans:
                assert not "[PAD]" in sentence[start:end + 1]
                assert not any([
                    x.startswith(self.ent_prefix)
                    for x in sentence[start:end + 1]
                ])
                sentence = sentence[:start] + [
                    self.ent_prefix + entity
                ] + ["[PAD]"] * (end - start) + sentence[end + 1:]

        return spans

    def predict_aida(self, in_file, out_file, batch_size=4, iterations=1):
        data = self.read_aida_file(in_file)  #, ignore_gold = False)

        predictions = {}
        for idx, (sentence, _) in tqdm(list(data.items()), desc="Prediction"):
            predictions[idx] = self.predict_sentence(
                sentence, batch_size=batch_size,
                iterations=iterations)  #, gold_spans = None)

        norm_predictions = []
        offset = 0
        for key in sorted(list(predictions.keys())):
            predictions[key].sort(key=lambda x: x[1])
            for pred, start, end in predictions[key]:
                surface = " ".join(data[key][0][start:end + 1])
                norm_predictions.append(
                    f"{start+offset}\t{end+offset}\t{pred}\t{surface}\n")
            offset += len(data[key][0])

        with open(out_file, "w") as whandle:
            whandle.write("".join(norm_predictions))

    def train(self,
              train_file,
              dev_file,
              model_dir,
              batch_size=128,
              eval_batch_size=4,
              gradient_accumulation_steps=16,
              verbose=True,
              epochs=15,
              warmup_proportion=0.1,
              lr=5e-5,
              do_reinit_lm=False,
              beta2=0.999):

        self.ent2idx = {None: 0}

        if do_reinit_lm:
            for module in self.model.cls.predictions.transform.modules():
                self.model._init_weights(module)

        train_data = self.read_aida_file(train_file, ignore_gold=False)
        dev_data = self.read_aida_file(dev_file, ignore_gold=False)
        train_samples = self.data2samples(train_data)
        dev_samples = self.data2samples(dev_data)

        self.parameters = [self.null_vector]
        optimizer_grouped_parameters = [{
            'params': [self.null_vector],
            'weight_decay': 0.01
        }, {
            'params': [],
            'weight_decay': 0.0
        }]

        if self.do_use_priors:
            self.parameters.append(self.null_bias)
            optimizer_grouped_parameters[1]["params"].append(self.null_bias)

        for n, p in self.model.named_parameters():
            i = 1 if any([nd in n for nd in self.NO_DECAY]) else 0
            optimizer_grouped_parameters[i]['params'].append(p)
            self.parameters.append(p)

        num_train_steps_per_epoch = len(train_samples) // batch_size + int(
            len(train_samples) % batch_size != 0)
        num_train_steps = epochs * num_train_steps_per_epoch

        if beta2:
            self.optimizer = AdamW(optimizer_grouped_parameters,
                                   lr=lr,
                                   betas=(0.9, beta2))
        else:
            self.optimizer = AdamW(optimizer_grouped_parameters, lr=lr)
        self.scheduler = WarmupLinearSchedule(self.optimizer,
                                              warmup_steps=warmup_proportion *
                                              num_train_steps,
                                              t_total=num_train_steps)

        best_f1 = -1
        self.save(model_dir, epoch=0)

        for _ in trange(epochs, desc="Epoch", disable=not verbose):
            self.rnd.shuffle(train_samples)

            train_loss = self.train_loop(
                train_samples,
                batch_size=batch_size // gradient_accumulation_steps,
                gradient_accumulation_steps=gradient_accumulation_steps)

            prec, rec, f1 = self.eval_loop(dev_samples,
                                           batch_size=eval_batch_size)
            self.save(model_dir, epoch=_ + 1)

            if f1 > best_f1:
                best_f1 = f1
                print(
                    "\nNew best micro F1 in epoch {}! P R F1: {:.4} {:.4} {:.4}"
                    .format(_ + 1, prec, rec, f1),
                    flush=True)
                print(
                    "(This is an estimate. Use predict functions and the scorer for the real result.)",
                    flush=True)
                self.save(model_dir, epoch=None)

    def save(self, model_dir, epoch=None):
        f_model = os.path.join(model_dir,
                               "model.pth") if epoch is None else os.path.join(
                                   model_dir, f"model_{epoch}.pth")
        f_null_vector = os.path.join(
            model_dir, "null_vector.pth") if epoch is None else os.path.join(
                model_dir, f"null_vector_{epoch}.pth")
        torch.save(self.model.state_dict(), f_model)
        torch.save(self.null_vector, f_null_vector)

        if self.do_use_priors:
            f_null_bias = os.path.join(
                model_dir, "null_bias.pth") if epoch is None else os.path.join(
                    model_dir, f"null_bias_{epoch}.pth")
            torch.save(self.null_bias, f_null_bias)

    def load(self, model_dir, epoch=None):
        f_model = os.path.join(model_dir,
                               "model.pth") if epoch is None else os.path.join(
                                   model_dir, f"model_{epoch}.pth")
        f_null_vector = os.path.join(
            model_dir, "null_vector.pth") if epoch is None else os.path.join(
                model_dir, f"null_vector_{epoch}.pth")
        self.model.load_state_dict(torch.load(f_model))
        self.null_vector.data = torch.load(f_null_vector).data

        if self.do_use_priors:
            f_null_bias = os.path.join(
                model_dir, "null_bias.pth") if epoch is None else os.path.join(
                    model_dir, f"null_bias_{epoch}.pth")
            self.null_bias.data = torch.load(f_null_bias).data

    def make_input_dict(self, samples):
        if self.dynamic_max_len:
            max_len = max([len(sample.tokenized) for sample in samples])
        else:
            max_len = self.max_len

        input_ids = torch.zeros((len(samples), max_len)).to(dtype=torch.long)
        attention_mask = torch.zeros_like(input_ids)

        for i, sample in enumerate(samples):
            assert len(sample.tokenized) <= max_len
            assert len(sample.tokenized) == len(sample.input_ids)
            input_ids[i, :len(sample.input_ids)] = torch.tensor(
                sample.input_ids).to(dtype=input_ids.dtype)
            attention_mask[i, :len(sample.input_ids)] = 1

        input_embeddings = self.bert_emb(input_ids)
        #unk_id = self.tokenizer.vocab["[UNK]"]

        for i, sample in enumerate(samples):
            for j, token in enumerate(sample.tokenized):
                if j == sample.mask_pos:
                    assert token == "[MASK]"
                    if self.do_prime_mask:
                        assert sample.candidates[0] is None
                        candidate_embeddings = np.array([
                            self.ebert_emb[self.ent_prefix + ent]
                            for ent in sample.candidates[1:]
                        ])
                        input_embeddings[i, j, :] = torch.tensor(
                            np.mean(candidate_embeddings,
                                    0)).to(dtype=input_embeddings.dtype)

                elif token.startswith(self.ent_prefix):
                    input_embeddings[i, j, :] = torch.tensor(
                        self.ebert_emb[token]).to(dtype=input_embeddings.dtype)

        input_embeddings = input_embeddings.to(device=self.device)
        attention_mask = attention_mask.to(device=self.device)
        label_ids = torch.tensor([sample.correct_idx for sample in samples
                                  ]).to(dtype=torch.long, device=self.device)

        return {
            "input_ids": input_embeddings,
            "attention_mask": attention_mask,
            "label_ids": label_ids
        }

    def train_loop(self,
                   samples,
                   batch_size,
                   gradient_accumulation_steps,
                   verbose=True):
        return self.loop(
            samples=samples,
            batch_size=batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            mode="train",
            verbose=verbose)

    def eval_loop(self, samples, batch_size, verbose=True):
        return self.loop(samples=samples,
                         batch_size=batch_size,
                         gradient_accumulation_steps=1,
                         mode="eval",
                         verbose=verbose)

    def pred_loop(self, samples, batch_size, verbose=True):
        return self.loop(samples=samples,
                         batch_size=batch_size,
                         gradient_accumulation_steps=1,
                         mode="pred",
                         verbose=verbose)

    def loop(self,
             samples,
             batch_size,
             mode,
             gradient_accumulation_steps,
             verbose=1):
        assert mode in ("train", "eval", "pred")

        all_true, all_pred, all_losses, all_pred_spans = [], [], [], []

        if mode == "train":
            self.model.train()
        else:
            self.model.eval()

        idx2ent = {self.ent2idx[ent]: ent for ent in self.ent2idx}
        assert len(idx2ent) == len(self.ent2idx)
        assert sorted(list(idx2ent.keys())) == list(range(len(idx2ent)))

        entity_embedding = torch.zeros(
            (len(idx2ent), self.null_vector.shape[0]))
        entity_embedding[1:] = torch.tensor(self.ebert_emb[[
            self.ent_prefix + idx2ent[idx] for idx in range(1, len(idx2ent))
        ]])
        entity_embedding = entity_embedding.to(dtype=self.null_vector.dtype)

        for step, i in enumerate(
                trange(0,
                       len(samples),
                       batch_size,
                       desc=f"Iterations ({mode})",
                       disable=not verbose)):
            batch = samples[i:i + batch_size]

            all_true.extend([sample.correct_idx for sample in batch])
            mask_positions = [sample.mask_pos for sample in batch]

            input_dict = self.make_input_dict(batch)
            label_ids = input_dict.pop("label_ids")

            all_entities, ranges = [], []
            for j, sample in enumerate(batch):
                all_entities.extend(sample.candidate_ids[1:])
                ranges.append([
                    len(all_entities) - len(sample.candidate_ids[1:]),
                    len(all_entities)
                ])

            all_outputs = self.model(**input_dict)[0]
            outputs = torch.stack([
                all_outputs[j, position]
                for j, position in enumerate(mask_positions)
            ])

            batch_entity_embedding = entity_embedding[torch.tensor(
                all_entities)].to(
                    device=self.device
                )  # move entity embeddings for entire batch to GPU

            batch_loss = 0
            for j, sample in enumerate(batch):
                assert len(
                    sample.candidate_ids) == ranges[j][1] - ranges[j][0] + 1
                candidates_with_zero = torch.cat([
                    self.null_vector.unsqueeze(0),
                    batch_entity_embedding[ranges[j][0]:ranges[j][1]]
                ])
                logits = candidates_with_zero.matmul(outputs[j])

                if self.do_use_priors:
                    assert sample.biases[0] is None
                    biases = torch.tensor(sample.biases[1:]).to(
                        device=self.null_bias.device,
                        dtype=self.null_bias.dtype)
                    logits += torch.cat([self.null_bias, biases])

                probas = torch.softmax(logits, -1)

                if mode == "eval":
                    probas_numpy = probas.detach().cpu().numpy()
                    all_pred.append(probas_numpy.argmax())

                if mode == "pred":
                    probas_numpy = probas.detach().cpu().numpy()
                    entities = [None] + [
                        idx2ent[i] for i in sample.candidate_ids[1:]
                    ]
                    all_pred_spans.append(
                        (entities[probas_numpy.argmax()], sample.data_idx,
                         sample.start, sample.end, [
                             (ent, float(p))
                             for ent, p in zip(entities, probas_numpy)
                         ]))

                elif mode == "train":
                    sample_loss = -torch.log(probas[label_ids[j]])

                    batch_loss += sample_loss / len(batch)
                    all_losses.append(float(sample_loss.item()))

            if mode == "train":
                batch_loss.backward()

                if (step + 1) % gradient_accumulation_steps == 0:
                    torch.nn.utils.clip_grad_norm_(self.parameters, 1.0)
                    self.optimizer.step()
                    self.scheduler.step()
                    self.optimizer.zero_grad()

        if mode == "pred":
            return all_pred_spans
        if mode == "eval":
            return self.score_f1(all_true, all_pred)
        elif mode == "train":
            return np.mean(all_losses)
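Taken together, the constructor and the train, load and predict_aida methods above suggest the following end-to-end flow. This is a minimal usage sketch, not part of the original listing: the AIDA file names and the model directory are placeholders, and it assumes the wikipedia2vec embedding file and the linear mapper referenced by the constructor are available locally.

    # Hypothetical paths; keyword arguments mirror the constructor defaults documented above.
    linker = EntityLinkingAsLM(bert_name="bert-base-cased",
                               ebert_name="wikipedia2vec-base-cased",
                               granularity="document",
                               do_prime_mask=True)

    # Fine-tune on AIDA-style files, keeping the best dev-F1 checkpoint in model_dir.
    linker.train(train_file="aida_train.txt",
                 dev_file="aida_dev.txt",
                 model_dir="out/el_model",
                 batch_size=128,
                 epochs=15)

    # Reload the best checkpoint and write span-level predictions as tab-separated lines.
    linker.load("out/el_model")
    linker.predict_aida(in_file="aida_test.txt",
                        out_file="out/el_predictions.tsv",
                        batch_size=4)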
Example #7
    def test_wiki_linking_reader(self):

        fake_entity_world = {
            "Germany": "11867",
            "United_Kingdom": "31717",
            "European_Commission": "42336"
        }
        candidate_generator = WikiCandidateMentionGenerator(
            'tests/fixtures/linking/priors.txt',
            entity_world_path=fake_entity_world)
        train_file = 'tests/fixtures/linking/aida.txt'

        reader = LinkingReader(mention_generator=candidate_generator)
        instances = reader.read(train_file)

        instances = list(instances)

        fields = instances[0].fields

        text = [x.text for x in fields["tokens"].tokens]
        assert text == [
            'EU', 'rejects', 'German', 'call', 'to', 'boycott', 'British',
            'lamb', '.'
        ]

        spans = fields["candidate_spans"].field_list
        span_starts, span_ends = zip(*[(field.span_start, field.span_end)
                                       for field in spans])
        assert span_starts == (6, 2)
        assert span_ends == (6, 2)
        gold_ids = [x.text for x in fields["gold_entities"].tokens]
        assert gold_ids == ['United_Kingdom', 'Germany']

        candidate_token_list = [
            x.text for x in fields["candidate_entities"].tokens
        ]
        candidate_tokens = []
        for x in candidate_token_list:
            candidate_tokens.extend(x.split(" "))

        assert candidate_tokens == ['United_Kingdom', 'Germany']

        numpy.testing.assert_array_almost_equal(
            fields["candidate_entity_prior"].array, numpy.array([[1.], [1.]]))
        fields = instances[1].fields
        text = [x.text for x in fields["tokens"].tokens]
        assert text == [
            'The', 'European', 'Commission', 'said', 'on', 'Thursday', 'it',
            'disagreed', 'with', 'German', 'advice', 'to', 'consumers', 'to',
            'shun', 'British', 'lamb', 'until', 'scientists', 'determine',
            'whether', 'it', 'is', 'dangerous'
        ]

        spans = fields["candidate_spans"].field_list
        span_starts, span_ends = zip(*[(field.span_start, field.span_end)
                                       for field in spans])
        assert span_starts == (15, 9)
        assert span_ends == (15, 9)
        gold_ids = [x.text for x in fields["gold_entities"].tokens]
        # id not inside our mini world, should be ignored
        assert "European_Commission" not in gold_ids
        assert gold_ids == ['United_Kingdom', 'Germany']
        candidate_token_list = [
            x.text for x in fields["candidate_entities"].tokens
        ]
        candidate_tokens = []
        for x in candidate_token_list:
            candidate_tokens.extend(x.split(" "))
        assert candidate_tokens == ['United_Kingdom', 'Germany']

        numpy.testing.assert_array_almost_equal(
            fields["candidate_entity_prior"].array, numpy.array([[1.], [1.]]))