Beispiel #1
0
    def _set_data_desc_to_cfg(self, cfg: DictConfig, data_dir: str,
                              train_ds: DictConfig,
                              validation_ds: DictConfig) -> None:
        """
        Build a MultiLabelIntentSlotDataDesc and copy its values into ``cfg``.

        The intent/slot labels, label ids, weights and the pad label are stored
        under ``cfg.data_desc`` so they can be reused later (e.g. in inference).
        The label-id mappings are also written to files inside ``data_dir`` via
        ``self._save_label_ids`` and those files are registered through
        ``self.register_artifact``.

        Struct mode on ``cfg`` is switched off while the new keys are added and
        is always switched back on before returning, even if a step fails.

        Args:
            cfg: configuration object that is updated in place
            data_dir: data directory; also the directory the label files are
                written to
            train_ds: training dataset config; its ``prefix`` is used as a
                data descriptor mode
            validation_ds: validation dataset config; its ``prefix`` is used
                as a data descriptor mode

        Returns:
            None
        """
        # Save data from data desc to config - so it can be reused later, e.g. in inference.
        data_desc = MultiLabelIntentSlotDataDesc(
            data_dir=data_dir, modes=[train_ds.prefix, validation_ds.prefix])

        # New keys can only be added while struct mode is off.
        OmegaConf.set_struct(cfg, False)
        try:
            if not hasattr(cfg, "data_desc") or cfg.data_desc is None:
                cfg.data_desc = {}
            # Intents.
            cfg.data_desc.intent_labels = list(
                data_desc.intents_label_ids.keys())
            cfg.data_desc.intent_label_ids = data_desc.intents_label_ids
            cfg.data_desc.intent_weights = data_desc.intent_weights
            # Slots.
            cfg.data_desc.slot_labels = list(data_desc.slots_label_ids.keys())
            cfg.data_desc.slot_label_ids = data_desc.slots_label_ids
            cfg.data_desc.slot_weights = data_desc.slot_weights

            cfg.data_desc.pad_label = data_desc.pad_label

            # for older(pre - 1.0.0.b3) configs compatibility
            if not hasattr(cfg, "class_labels") or cfg.class_labels is None:
                cfg.class_labels = OmegaConf.create({
                    "intent_labels_file": "intent_labels.csv",
                    "slot_labels_file": "slot_labels.csv",
                })

            slot_labels_file = os.path.join(
                data_dir, cfg.class_labels.slot_labels_file)
            intent_labels_file = os.path.join(
                data_dir, cfg.class_labels.intent_labels_file)
            self._save_label_ids(data_desc.slots_label_ids, slot_labels_file)
            self._save_label_ids(data_desc.intents_label_ids,
                                 intent_labels_file)

            self.register_artifact("class_labels.intent_labels_file",
                                   intent_labels_file)
            self.register_artifact("class_labels.slot_labels_file",
                                   slot_labels_file)
        finally:
            # Re-enable struct mode even when a step above raised, so the
            # config is never left silently accepting unknown keys.
            OmegaConf.set_struct(cfg, True)