Example #1
def test_forward_pass():
    data = MaskedBatchedExamples.build(
        [
            Example(
                Words.build(torch.IntTensor([232, 13, 13, 21, 5, 1]),
                            spans=[[0, 1], [1, 3], [5, 6]]),
                Entities.build(torch.IntTensor([5]), [[1, 3]], 128, 30),
            ),
            Example(
                Words.build(torch.IntTensor(
                    [29, 28, 5000, 22, 11, 55, 1, 1, 1, 1]),
                            spans=[[0, 1], [1, 2], [2, 5], [5, 10]]),
                Entities.build(torch.IntTensor([11, 5]), [[0, 2], [2, 10]],
                               128, 30),
            ),
        ],
        torch.device("cpu"),
        42,
        11,
        0.1,
        0.1,
        0.1,
        (5, 42),
        0.1,
    )
    cfg = AutoConfig.from_pretrained(daBERT)
    model = PretrainTaskDaLUKE(cfg, 66, 79)
    word_scores, ent_scores = model(data)
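    # Scores over the word vocabulary and the entity vocabulary, respectively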
    assert word_scores.shape[0] >= 2
    assert ent_scores.shape[0] == 2

    assert word_scores.shape[1] == 32_000
    assert ent_scores.shape[1] == 66
Example #2
def example_from_str(
        text:             str,
        entity_spans:     list[tuple[int, int]],
        daluke:           AutoDaLUKE,
    ) -> BatchedExamples:
    subword_ids = get_subword_ids(text, daluke.tokenizer)
    sep, cls_, pad, _, _ = get_special_ids(daluke.tokenizer)
    flat_subword_ids = list(chain(*subword_ids))
    # Reduce the word ids to the smaller vocab if we use monolinguification of a multilingual model
    if daluke.token_map is not None:
        flat_subword_ids = daluke.token_map[flat_subword_ids]

    w = Words.build(
        ids     = torch.IntTensor(flat_subword_ids),
        max_len = daluke.metadata["max-seq-length"],
        sep_id  = sep,
        cls_id  = cls_,
        pad_id  = pad,
    )
    e = Entities.build(
        ids             = get_entity_id_tensor(text, entity_spans, daluke.entity_vocab),
        spans           = get_entity_subword_spans(subword_ids, entity_spans),
        max_entities    = daluke.metadata["max-entities"],
        max_entity_span = daluke.metadata["max-entity-span"],
    )
    return BatchedExamples.build(
        [
            Example(
                words    = w,
                entities = e,
            ),
        ],
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    )
Example #3
def features_from_str(words: list[str], entity_spans: list[tuple[int, int]],
                      entity_vocab: dict[str, int],
                      tokenizer: AutoTokenizer) -> Example:
    word_ids = torch.IntTensor(tokenizer.convert_tokens_to_ids(words))
    ents = (" ".join(words[e[0]:e[1]]) for e in entity_spans)
    ent_ids = torch.IntTensor(
        [entity_vocab.get(ent, entity_vocab["[UNK]"]) for ent in ents])
    return Example(words=Words.build(word_ids),
                   entities=Entities.build(ent_ids, entity_spans, 128, 30))
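
# Hypothetical usage of features_from_str (names below are illustrative, not from the repo):
#   ex = features_from_str(
#       words=["Danmarks", "hovedstad", "er", "København"],
#       entity_spans=[(3, 4)],      # word-level span covering "København"
#       entity_vocab=entity_vocab,  # maps entity strings to ids; must contain "[UNK]"
#       tokenizer=tokenizer,        # a transformers tokenizer whose vocab the words come from
#   )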
Example #4
def test_words():
    words = Words.build(
        torch.IntTensor([22, 48, 99]),
        max_len=10,
    )
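    # Words.build prepends the CLS id (2 by default), appends the SEP id (3), and pads with 0 up to max_len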
    assert torch.equal(words.ids,
                       torch.IntTensor([2, 22, 48, 99, 3, 0, 0, 0, 0, 0]))
    assert torch.equal(words.attention_mask,
                       torch.IntTensor([1, 1, 1, 1, 1, 0, 0, 0, 0, 0]))
Example #5
    def build_examples(self) -> tuple[list[Example], list[Example]]:
        train_examples, val_examples = list(), list()
        with open(os.path.join(self.data_dir, DatasetBuilder.data_file)) as f,\
            TT.profile("Build example", hits=self.metadata["number-of-items"]):
            for seq_data in load_jsonl(f):
                is_validation = seq_data["is_validation"]
                if self.only_load_validation and not is_validation:
                    continue
                if self.ent_min_mention:
                    # Keep only entities in filtered entity vocab
                    seq_data["entity_spans"] = [
                        span for id_, span in zip(seq_data["entity_ids"],
                                                  seq_data["entity_spans"])
                        if id_ in self.ent_ids
                    ]
                    seq_data["entity_ids"] = [
                        id_ for id_ in seq_data["entity_ids"]
                        if id_ in self.ent_ids
                    ]

                ex = Example(
                    words=Words.build(
                        torch.IntTensor(seq_data["word_ids"]),
                        seq_data["word_spans"],
                        max_len=self.max_sentence_len,
                        pad_id=self.pad_id,
                    ),
                    entities=Entities.build(
                        torch.IntTensor(seq_data["entity_ids"]),
                        seq_data["entity_spans"],
                        max_entities=self.max_entities,
                        max_entity_span=self.max_entity_span,
                    ),
                )
                if is_validation:
                    val_examples.append(ex)
                else:
                    train_examples.append(ex)

        return train_examples, val_examples
Example #6
    def _build_examples(self, split: Split) -> list[NERExample]:
        examples = list()

        for i, (text, annotation, bounds) in enumerate(
                zip(self.data[split].texts, self.data[split].annotations,
                    self.data[split].sentence_boundaries)):
            text_token_ids: list[list[int]] = self.tokenizer(
                text, add_special_tokens=False)["input_ids"]
            # We might have to split some sentences to respect the maximum sentence length
            bounds = self._add_extra_sentence_boundaries(
                bounds, text_token_ids)
            # TODO: Consider the current sentence splitting: Do we throw away context in situations where we actually have sentence-document information? (Not relevant for DaNE)
            for j, end in enumerate(bounds):
                start = bounds[j - 1] if j else 0
                # Flatten structure of [[subwords], [subwords], ... ]
                word_ids = list(chain(*text_token_ids[start:end]))
                # Reduce the word ids to the smaller vocab if we use monolinguification of a multilingual model
                if self.token_map is not None:
                    word_ids = self.token_map[word_ids]
                true_entity_fullword_spans = self._segment_entities(
                    annotation[start:end])
                # The cumulative length of each word in units of subwords
                cumlength = np.cumsum(
                    [len(t) for t in text_token_ids[start:end]])
                # Save the spans of entities as they are in the token list
                true_entity_subword_spans = {
                    (cumlength[s - 1] + 1 if s else 1, cumlength[e - 1] + 1):
                    ann
                    for (s, e), ann in true_entity_fullword_spans.items()
                }  # +1 for CLS token
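                # E.g. if the sentence's subword ids are [[12], [34, 56], [78]], then cumlength == [1, 3, 4],
                # so the fullword span (1, 3) maps to the subword span (2, 5) once the leading CLS token is counted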
                assert all(e-s <= self.max_entity_span for s, e in true_entity_subword_spans),\
                        f"Example {i}, sentence {j} contains an entity longer than limit of {self.max_entity_span} tokens. Text:\n\t{text}"
                assert len(true_entity_subword_spans) < self.max_entities,\
                        f"Example {i}, sentence {j} contains {len(true_entity_subword_spans)} entities, but only {self.max_entities} are allowed. Text:\n\t{text}"

                all_entity_fullword_spans = self._generate_all_entity_spans(
                    true_entity_fullword_spans, text_token_ids[start:end],
                    cumlength)
                # +1 for CLS token
                all_entity_subword_spans = [
                    (cumlength[s - 1] + 1 if s else 1, cumlength[e - 1] + 1)
                    for s, e in all_entity_fullword_spans
                ]

                # We don't use the actual entity ids: we simply code them all as masked, corresponding to ID 1, as we mutated this to be the case
                entity_ids = torch.ones(len(all_entity_subword_spans),
                                        dtype=torch.int)
                entity_labels = torch.LongTensor([
                    self.label_to_idx[true_entity_subword_spans.get(
                        span, self.null_label)]
                    for span in all_entity_subword_spans
                ])
                # If there are too many possible spans for self.max_entities, we must divide the sequence into multiple examples
                n_sub_examples = math.ceil(
                    len(all_entity_subword_spans) / self.max_entities)
                for sub_example in range(n_sub_examples):
                    substart = self.max_entities * sub_example
                    subend = self.max_entities * (sub_example + 1)

                    entities = Entities.build(
                        torch.IntTensor(entity_ids[substart:subend]),
                        torch.IntTensor(
                            all_entity_subword_spans[substart:subend]),
                        max_entities=self.max_entities,
                        max_entity_span=self.max_entity_span,
                    )
                    words = Words.build(
                        torch.IntTensor([self.cls_id, *word_ids, self.sep_id]),
                        max_len=self.max_seq_length,
                        pad_id=self.pad_id,
                    )
                    examples.append(
                        NERExample(
                            words=words,
                            entities=NEREntities.build_from_entities(
                                entities,
                                fullword_spans=all_entity_fullword_spans[
                                    substart:subend],
                                labels=entity_labels[substart:subend],
                                max_entities=self.max_entities,
                            ),
                            text_num=i,
                        ))
            # Handy for debugging on a smaller data set
            if self.data_limit is not None and i == self.data_limit:
                break
        return examples
Example #7
def test_word_masking():
    w = Words(ids=torch.IntTensor([
        [42] * 10,
        [69, 5, 60, 60, 3] + [-1] * 5,
    ]),
              attention_mask=None,
              N=torch.IntTensor([10, 5]),
              spans=[
                  torch.IntTensor([[0, 10]]),
                  torch.IntTensor([[0, 1], [2, 4]]),
              ])
    w1 = deepcopy(w)
    _, mask = mask_word_batch(
        w1,
        0.5,
        0.25,
        0.25,
        (5, 29),
        999,
    )
    assert sum(mask[0]) == 10  # Entire 10-long word must be masked
    assert sum(mask[1]) in (
        1, 2
    )  # Either the first 1-long word or the second 2-long word must be masked
    assert sum(w1.ids.ravel() == 999) <= 12

    w2 = deepcopy(w)
    _, mask = mask_word_batch(
        w2,
        1,
        1,
        0,
        (5, 29),
        999,
    )
    assert sum(mask[0]) == 10  # Entire 10-long word must be masked
    assert sum(mask[1]) == 3  # Both the 1-long and the 2-long word are masked
    assert torch.equal(w2.ids, w.ids)  # Ids are unchanged, as every selected token is "unmasked" and keeps its original id

    w3 = deepcopy(w)
    _, mask = mask_word_batch(
        w3,
        1,
        0,
        1,
        (69, 70),
        999,
    )
    assert sum(mask[0]) == 10
    assert sum(mask[1]) == 3
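    # With random-word probability 1 and id range (69, 70), every position inside a word span is replaced by 69;
    # tokens outside the spans (ids 5 and 3) and the padding (-1) are left untouched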
    assert torch.equal(
        w3.ids,
        torch.IntTensor([[69, 69, 69, 69, 69, 69, 69, 69, 69, 69],
                         [69, 5, 69, 69, 3, -1, -1, -1, -1, -1]]))

    w4 = deepcopy(w)
    _, mask = mask_word_batch(
        w4,
        0.5,
        0,
        0,
        (0, 1),
        999,
    )
    assert sum(w4.ids[0] == 999) == 10
    assert sum(w4.ids[1] == 999) in (1, 2)