from typing import Dict, List

import torch
import torch.nn as nn

# Transform, Token, SpecialTokens, Vocabulary, and Slot are PyText/project
# helpers assumed to be in scope (their exact import paths depend on the repo).


class SlotLabelTransform(Transform):
    def __init__(self, poss_slots: List[str], tokenizer: nn.Module = None):
        super().__init__()
        self.NO_LABEL = Token("NoLabel")
        poss_slots = list(poss_slots)
        # Reserve NoLabel, PAD, and UNK at fixed indices in the vocabulary.
        if self.NO_LABEL not in poss_slots:
            poss_slots.insert(0, self.NO_LABEL)
        if SpecialTokens.PAD not in poss_slots:
            poss_slots.insert(1, SpecialTokens.PAD)
        if SpecialTokens.UNK not in poss_slots:
            poss_slots.insert(2, SpecialTokens.UNK)
        self.vocab = Vocabulary(poss_slots)

    def process_slots(self, slots_list: str) -> List[Slot]:
        """Parse a comma-separated string of 'start:end:label' entries."""
        if "," in slots_list:
            slots_list = slots_list.split(",")
        elif slots_list != "":
            slots_list = [slots_list]
        else:
            return []
        slot_labels: List[Slot] = []
        for curr_slot in slots_list:
            first_delim = curr_slot.find(":")
            second_delim = curr_slot.find(":", first_delim + 1)
            start_ind = int(curr_slot[0:first_delim])
            end_ind = int(curr_slot[first_delim + 1 : second_delim])
            slot_name = curr_slot[second_delim + 1 :]
            slot_labels.append(Slot(slot_name, start_ind, end_ind))
        return slot_labels

    def forward(self, text_and_slots):
        """
        Turn slot labels and text into a list of token labels with the
        same length as the number of tokens in the text.
        """
        tokens, start, end = text_and_slots[0].values()
        slots = self.process_slots(text_and_slots[1])
        curr_slot_i = 0
        curr_token_i = 0
        slot_labels: List[str] = []
        # Walk tokens and slots in parallel; a token takes a slot's label
        # when its character span overlaps the slot, NoLabel otherwise.
        while curr_token_i < len(tokens) and curr_slot_i < len(slots):
            curr_slot = slots[curr_slot_i]
            if int(start[curr_token_i]) > curr_slot.end:
                curr_slot_i += 1
            else:
                if int(end[curr_token_i]) > curr_slot.start:
                    slot_labels.append(curr_slot.label)
                else:
                    slot_labels.append(self.NO_LABEL)
                curr_token_i += 1
        # Tokens past the last slot all get NoLabel.
        slot_labels += [self.NO_LABEL] * (len(tokens) - curr_token_i)
        slot_label_idx = self.vocab.lookup_all(slot_labels)
        return {"slot_labels": torch.tensor(slot_label_idx)}

    @property
    def is_jitable(self) -> bool:
        return False
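# A minimal, self-contained sketch of the span-to-token alignment performed
# in forward() above, using a plain NamedTuple stand-in for PyText's Slot
# type (the helper names and the example offsets below are illustrative).
from typing import List, NamedTuple


class _Slot(NamedTuple):
    label: str
    start: int
    end: int


def _align_slots(starts: List[int], ends: List[int], slots: List[_Slot]) -> List[str]:
    labels: List[str] = []
    slot_i = tok_i = 0
    while tok_i < len(starts) and slot_i < len(slots):
        slot = slots[slot_i]
        if starts[tok_i] > slot.end:  # token begins after the slot ends
            slot_i += 1
        else:
            labels.append(slot.label if ends[tok_i] > slot.start else "NoLabel")
            tok_i += 1
    labels += ["NoLabel"] * (len(starts) - tok_i)
    return labels


# "set an alarm for eight" with a datetime slot over "eight" (chars 17-22):
# every token but the last falls outside the slot and gets NoLabel.
print(_align_slots([0, 4, 7, 13, 17], [3, 6, 12, 16, 22], [_Slot("datetime", 17, 22)]))
# ['NoLabel', 'NoLabel', 'NoLabel', 'NoLabel', 'datetime']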
class LabelTransform(Transform):
    def __init__(self, label_names: List[str]):
        super().__init__()
        self.vocab = Vocabulary(label_names)

    def forward(self, label: str) -> Dict[str, torch.Tensor]:
        label_id = self.vocab.lookup_all(label)
        return {"label_ids": torch.tensor(label_id, dtype=torch.long)}

    @property
    def is_jitable(self) -> bool:
        return False
class LabelTransform(Transform):
    """Variant of LabelTransform above that also reserves an UNK index
    and exposes the label-to-index mapping."""

    def __init__(self, label_names: List[str]):
        super().__init__()
        # Reserve index 0 for unknown labels so unseen inputs still map
        # to a valid id.
        if SpecialTokens.UNK not in label_names:
            label_names.insert(0, SpecialTokens.UNK)
        self.vocab = Vocabulary(label_names)

    def forward(self, label: str) -> Dict[str, torch.Tensor]:
        label_id = self.vocab.lookup_all(label)
        return {"label_ids": torch.tensor(label_id, dtype=torch.long)}

    @property
    def is_jitable(self) -> bool:
        return False

    @property
    def labels(self) -> Dict[str, int]:
        return self.vocab.idx
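# Hypothetical usage sketch for the UNK-aware LabelTransform above, assuming
# PyText's Vocabulary maps unseen strings to the reserved UNK index; the
# label names and printed indices are illustrative.
transform = LabelTransform(["alarm/set", "alarm/cancel", "reminder/set"])
print(transform("alarm/cancel"))  # {'label_ids': tensor(2)} -- UNK holds index 0
print(transform.labels)  # e.g. {SpecialTokens.UNK: 0, 'alarm/set': 1, ...}
print(transform("weather/find"))  # unseen label -> {'label_ids': tensor(0)}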
def _run_benchmark_pytext_vocab(toks, v: PytextVocabulary):
    for token_or_tokens_list in toks:
        v.lookup_all(token_or_tokens_list)
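# Hypothetical driver for the benchmark helper above; `vocab` is assumed to
# be an already-constructed PytextVocabulary, and the token data is made up.
# lookup_all is exercised with both single tokens and lists of tokens, which
# is why the helper iterates without unpacking its items.
import timeit

toks = [["the", "cat", "sat"], "dog", ["on", "the", "mat"]] * 1000
elapsed = timeit.timeit(
    lambda: _run_benchmark_pytext_vocab(toks, vocab), number=10
)
print(f"pytext vocab lookup: {elapsed:.3f}s")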