Example #1
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        """Initializes Token Classification Model."""
        # extract str to int labels mapping if a mapping file provided
        if isinstance(cfg.label_ids, str):
            if os.path.exists(cfg.label_ids):
                logging.info(
                    f'Reusing label_ids file found at {cfg.label_ids}.')
                label_ids = get_labels_to_labels_id_mapping(cfg.label_ids)
                # update the config to store name to id mapping
                cfg.label_ids = OmegaConf.create(label_ids)
            else:
                raise ValueError(f'{cfg.label_ids} not found.')

        self.class_weights = None
        super().__init__(cfg=cfg, trainer=trainer)

        self.classifier = TokenClassifier(
            hidden_size=self.hidden_size,
            num_classes=len(self._cfg.label_ids),
            num_layers=self._cfg.head.num_fc_layers,
            activation=self._cfg.head.activation,
            log_softmax=False,
            dropout=self._cfg.head.fc_dropout,
            use_transformer_init=self._cfg.head.use_transformer_init,
        )

        self.loss = self.setup_loss(
            class_balancing=self._cfg.dataset.class_balancing)

        # setup to track metrics
        self.classification_report = ClassificationReport(
            len(self._cfg.label_ids),
            label_ids=self._cfg.label_ids,
            dist_sync_on_step=True)
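The snippet above resolves string labels to integer ids via get_labels_to_labels_id_mapping. A minimal sketch of such a helper, assuming a plain-text label file with one label per line and zero-based ids, might look like the following (the actual NeMo utility may differ in details):

    import os


    def get_labels_to_labels_id_mapping(labels_file: str) -> dict:
        """Sketch only: map each label in a one-label-per-line file to its
        zero-based line index."""
        if not os.path.exists(labels_file):
            raise FileNotFoundError(f"{labels_file} not found.")
        with open(labels_file, "r") as f:
            labels = [line.strip() for line in f if line.strip()]
        return {label: idx for idx, label in enumerate(labels)}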
Example #2
    def __init__(
        self,
        data_dir: str,
        modes: List[str] = ["train", "test", "dev"],
        none_slot_label: str = "O",
        pad_label: int = -1,
    ):
        if not if_exist(data_dir, ["dict.intents.csv", "dict.slots.csv"]):
            raise FileNotFoundError(
                "Make sure that your data follows the standard format "
                "supported by MultiLabelIntentSlotDataset. Your data must "
                "contain dict.intents.csv and dict.slots.csv.")

        self.data_dir = data_dir
        self.intent_dict_file = self.data_dir + "/dict.intents.csv"
        self.slot_dict_file = self.data_dir + "/dict.slots.csv"

        self.intents_label_ids = get_labels_to_labels_id_mapping(
            self.intent_dict_file)
        self.num_intents = len(self.intents_label_ids)
        self.slots_label_ids = get_labels_to_labels_id_mapping(
            self.slot_dict_file)
        self.num_slots = len(self.slots_label_ids)

        infold = self.data_dir
        for mode in modes:
            if not if_exist(self.data_dir, [f"{mode}.tsv"]):
                logging.info(f" Stats calculation for {mode} mode"
                             f" is skipped as {mode}.tsv was not found.")
                continue
            logging.info(f" Stats calculating for {mode} mode...")
            slot_file = f"{self.data_dir}/{mode}_slots.tsv"
            with open(slot_file, "r") as f:
                slot_lines = f.readlines()

            input_file = f"{self.data_dir}/{mode}.tsv"
            with open(input_file, "r") as f:
                input_lines = f.readlines()[1:]  # Skipping headers at index 0

            if len(slot_lines) != len(input_lines):
                raise ValueError(
                    "Make sure that the number of slot lines match the "
                    "number of intent lines. There should be a 1-1 "
                    "correspondence between every slot and intent lines.")

            dataset = list(zip(slot_lines, input_lines))

            raw_slots, raw_intents = [], []
            for slot_line, input_line in dataset:
                slot_list = [int(slot) for slot in slot_line.strip().split()]
                raw_slots.append(slot_list)
                # The second tab-separated column holds the comma-separated
                # intent ids for this utterance.
                intent_ids = input_line.strip().split("\t")[1]
                intent_ids = list(map(int, intent_ids.split(",")))
                # Encode the intent ids as a multi-hot vector over all intents.
                parts = [
                    1 if label in intent_ids else 0
                    for label in range(self.num_intents)
                ]
                raw_intents.append(tuple(parts))

            logging.info(f"Three most popular intents in {mode} mode:")
            total_intents, intent_label_freq, max_id = get_multi_label_stats(
                raw_intents, infold + f"/{mode}_intent_stats.tsv")

            merged_slots = itertools.chain.from_iterable(raw_slots)
            logging.info(f"Three most popular slots in {mode} mode:")
            slots_total, slots_label_freq, max_id = get_label_stats(
                merged_slots, infold + f"/{mode}_slot_stats.tsv")

            logging.info(f"Total Number of Intent Labels: {total_intents}")
            logging.info(f"Intent Label Frequencies: {intent_label_freq}")
            logging.info(f"Total Number of Slots: {slots_total}")
            logging.info(f"Slots Label Frequencies: {slots_label_freq}")

            if mode == "train":
                intent_weights_dict = get_freq_weights_bce_with_logits_loss(
                    intent_label_freq)
                logging.info(f"Intent Weights: {intent_weights_dict}")
                slot_weights_dict = get_freq_weights(slots_label_freq)
                logging.info(f"Slot Weights: {slot_weights_dict}")

        # Class weights are derived from the frequencies computed on the
        # train split above (assumes train.tsv was present and processed).
        self.intent_weights = fill_class_weights(intent_weights_dict,
                                                 self.num_intents - 1)
        self.slot_weights = fill_class_weights(slot_weights_dict,
                                               self.num_slots - 1)

        if pad_label != -1:
            self.pad_label = pad_label
        else:
            if none_slot_label not in self.slots_label_ids:
                raise ValueError(f"none_slot_label {none_slot_label} not "
                                 f"found in {self.slot_dict_file}.")
            self.pad_label = self.slots_label_ids[none_slot_label]
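The inner loop above parses each slot line into a list of integer slot ids and turns the comma-separated intent ids from the second column of the corresponding input line into a multi-hot vector. A standalone sketch of that transformation on made-up sample lines (the file contents here are hypothetical):

    num_intents = 4

    # Hypothetical lines in the format the loop expects:
    # <utterance>\t<comma-separated intent ids>, and space-separated slot ids.
    input_line = "book a flight to paris\t1,3"
    slot_line = "54 54 54 54 12"

    slot_list = [int(slot) for slot in slot_line.strip().split()]
    intent_ids = list(map(int, input_line.strip().split("\t")[1].split(",")))
    multi_hot = [1 if label in intent_ids else 0 for label in range(num_intents)]

    print(slot_list)   # [54, 54, 54, 54, 12]
    print(multi_hot)   # [0, 1, 0, 1]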