Example #1
class SlotLabelTransform(Transform):
    def __init__(self, poss_slots: List[str], tokenizer: nn.Module = None):
        super().__init__()
        self.NO_LABEL = Token("NoLabel")
        poss_slots = list(poss_slots)
        if self.NO_LABEL not in poss_slots:
            poss_slots.insert(0, self.NO_LABEL)
        if SpecialTokens.PAD not in poss_slots:
            poss_slots.insert(1, SpecialTokens.PAD)
        if SpecialTokens.UNK not in poss_slots:
            poss_slots.insert(2, SpecialTokens.UNK)
        self.vocab = Vocabulary(poss_slots)

    def process_slots(self, slots_list: str) -> List[Slot]:
        """Parse a comma-separated string of 'start:end:label' spans into Slot objects."""
        if "," in slots_list:
            slots_list = slots_list.split(",")
        elif slots_list != "":
            slots_list = [slots_list]
        else:
            return []
        slot_labels: List[Slot] = []
        for curr_slot in slots_list:
            first_delim = curr_slot.find(":")
            second_delim = curr_slot.find(":", first_delim + 1)
            start_ind = int(curr_slot[0:first_delim])
            end_ind = int(curr_slot[first_delim + 1:second_delim])
            slot_name = curr_slot[second_delim + 1:]
            slot_labels.append(Slot(slot_name, start_ind, end_ind))
        return slot_labels

    def forward(self, text_and_slots):
        """
        Turn slot labels and text into a list of token labels with the same
        length as the number of tokens in the text.
        """
        tokens, start, end = text_and_slots[0].values()
        slots = self.process_slots(text_and_slots[1])
        curr_slot_i = 0
        curr_token_i = 0
        slot_labels: List[str] = []
        # Walk tokens and slots in parallel; both are assumed ordered by position.
        while curr_token_i < len(tokens) and curr_slot_i < len(slots):
            curr_slot = slots[curr_slot_i]
            if int(start[curr_token_i]) > curr_slot.end:
                # Token begins after the current slot ends: advance to the next slot.
                curr_slot_i += 1
            else:
                if int(end[curr_token_i]) > curr_slot.start:
                    # Token overlaps the slot span, so it inherits the slot's label.
                    slot_labels.append(curr_slot.label)
                else:
                    slot_labels.append(self.NO_LABEL)
                curr_token_i += 1
        # Tokens remaining after the last slot all get the no-label class.
        slot_labels += [self.NO_LABEL] * (len(tokens) - curr_token_i)
        slot_label_idx = self.vocab.lookup_all(slot_labels)
        return {"slot_labels": torch.tensor(slot_label_idx)}

    @property
    def is_jitable(self) -> bool:
        return False
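
A minimal, self-contained sketch of the "start:end:label" slot-string format this transform consumes and of the token/slot alignment loop above. The helper name demo_align and the sample sentence are illustrative only; the code re-implements the same parsing and alignment without the library types (Slot, Vocabulary, Token).

from typing import List, Tuple

def demo_align(slot_str: str, starts: List[int], ends: List[int]) -> List[str]:
    # Parse "start:end:label" entries, e.g. "12:19:weather,23:29:location".
    slots: List[Tuple[int, int, str]] = []
    for entry in slot_str.split(",") if slot_str else []:
        s, e, name = entry.split(":", 2)
        slots.append((int(s), int(e), name))
    labels, slot_i = [], 0
    for tok_start, tok_end in zip(starts, ends):
        # Skip slots that end before the current token begins.
        while slot_i < len(slots) and tok_start > slots[slot_i][1]:
            slot_i += 1
        if slot_i < len(slots) and tok_end > slots[slot_i][0]:
            labels.append(slots[slot_i][2])
        else:
            labels.append("NoLabel")
    return labels

# "What is the weather in Boston" with per-token character offsets:
print(demo_align("12:19:weather,23:29:location",
                 [0, 5, 8, 12, 20, 23],
                 [4, 7, 11, 19, 22, 29]))
# -> ['NoLabel', 'NoLabel', 'NoLabel', 'weather', 'NoLabel', 'location']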
Example #2
class LabelTransform(Transform):
    def __init__(self, label_names: List[str]):
        super().__init__()
        self.vocab = Vocabulary(label_names)

    def forward(self, label: str) -> Dict[str, torch.Tensor]:
        label_id = self.vocab.lookup_all(label)
        return {"label_ids": torch.tensor(label_id, dtype=torch.long)}

    @property
    def is_jitable(self) -> bool:
        return False
Example #3
class LabelTransform(Transform):
    def __init__(self, label_names: List[str]):
        super().__init__()
        label_names = list(label_names)  # copy so the caller's list is not mutated
        if SpecialTokens.UNK not in label_names:
            label_names.insert(0, SpecialTokens.UNK)
        self.vocab = Vocabulary(label_names)

    def forward(self, label: str) -> Dict[str, torch.Tensor]:
        label_id = self.vocab.lookup_all(label)
        return {"label_ids": torch.tensor(label_id, dtype=torch.long)}

    @property
    def is_jitable(self) -> bool:
        return False

    @property
    def labels(self) -> Dict[str, int]:
        return self.vocab.idx
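
The lookup performed in forward and the idx mapping exposed by the labels property can be illustrated with a stand-in vocabulary. MiniVocab below is hypothetical and only mimics the behavior the code above assumes from the library's Vocabulary.

import torch

class MiniVocab:
    # Stand-in for the library's Vocabulary: maps each name to its position.
    def __init__(self, names):
        self.idx = {name: i for i, name in enumerate(names)}

    def lookup_all(self, name):
        # Fall back to the UNK slot (index 0 here) for unseen labels.
        return self.idx.get(name, 0)

vocab = MiniVocab(["<unk>", "alarm/set", "alarm/cancel", "weather/find"])
print(vocab.lookup_all("alarm/cancel"))                                   # 2
print(torch.tensor(vocab.lookup_all("weather/find"), dtype=torch.long))   # tensor(3)
print(vocab.idx)   # what the labels property would return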
Example #4
def _run_benchmark_pytext_vocab(toks, v: PytextVocabulary):
    for token_or_tokens_list in toks:
        v.lookup_all(token_or_tokens_list)
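
A sketch of how this benchmark helper might be driven end to end with a wall-clock timer. PytextVocabulary is not constructed here; ToyVocab is a hypothetical stand-in exposing only the lookup_all interface the helper calls, and the input mix of single tokens and token lists is illustrative.

import time

class ToyVocab:
    # Stand-in exposing the lookup_all interface the benchmark expects.
    def __init__(self, tokens):
        self.idx = {tok: i for i, tok in enumerate(tokens)}

    def lookup_all(self, token_or_tokens_list):
        if isinstance(token_or_tokens_list, list):
            return [self.idx.get(t, 0) for t in token_or_tokens_list]
        return self.idx.get(token_or_tokens_list, 0)

def _run_benchmark_pytext_vocab(toks, v):
    for token_or_tokens_list in toks:
        v.lookup_all(token_or_tokens_list)

toks = [["hello", "world"], "hello", ["foo", "bar", "baz"]] * 10000
v = ToyVocab(["hello", "world", "foo", "bar", "baz"])
start = time.perf_counter()
_run_benchmark_pytext_vocab(toks, v)
print(f"lookup_all over {len(toks)} inputs: {time.perf_counter() - start:.4f}s")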