Example #1
    def bind(self,
             tensors_ref: List["NmTensor"],
             port_names: Optional[List[str]] = None):
        """
            Binds the "default" outputs.

            Args:
                tensors_ref: List of tensors to be added.
                port_names: List of port names (visible outside). If None, the internal tensor "output port names" are used.
        """
        # Set names.
        if port_names is None:
            port_names = [tensor.name for tensor in tensors_ref]

        for name, tensor in zip(port_names, tensors_ref):
            # Check the presence of the port name in "default" dictionary.
            if name in self._default_outputs.keys():
                # Name already present - use a unique name combining the producer step, producer name and port name.
                name = (str(tensor.producer_step_number) + "_" +
                        tensor.producer_name + "_" + tensor.name
                        )  # last = port name

                logging.debug(
                    "Setting unique name of the default output port `{}` produced in step {} by `{}` to `{}`"
                    .format(tensor.name, tensor.producer_step_number,
                            tensor.producer_name, name))
            # Store the output.
            self._default_outputs[name] = GraphOutput(
                tensor.ntype, tensor.producer_step_module_port)
Example #2
    def forward(self, user_uttr, dialog_history):
        """
        Returns dialogue utterances in the format accepted by the TRADE Dialogue state tracking model
        Args:
            dialog_history (list): dialogue history, list of system and user utterances
            user_uttr (str): user utterance
        Returns:
            dialog_ids (Tensor): token ids for the whole dialogue history
            dialog_lens (Tensor): length of the whole tokenized dialogue history
            dialog_history (list): updated dialogue history, list of system and user utterances
        """
        # TODO: why do we update the sys utterance, when only the user utterance is available at this point?
        dialog_history.append(["user", user_uttr])
        logging.debug("Dialogue history: %s", dialog_history)

        context = ' ; '.join(
            [item[1].strip().lower()
             for item in dialog_history]).strip() + ' ;'
        context_ids = self.data_desc.vocab.tokens2ids(context.split())
        dialog_ids = torch.tensor(context_ids).unsqueeze(0).to(self._device)
        dialog_lens = torch.tensor(len(context_ids)).unsqueeze(0).to(
            self._device)

        # logging.debug("!! dialog_ids: %s", dialog_ids)
        # logging.debug("!! dialog_lens: %s", dialog_lens)

        return dialog_ids, dialog_lens, dialog_history
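A standalone illustration (utterances are made up) of how the context string above is assembled from the dialogue history:

dialog_history = [["system", "How can I help?"], ["user", "Book a hotel."]]
context = ' ; '.join(item[1].strip().lower() for item in dialog_history).strip() + ' ;'
print(context)  # how can i help? ; book a hotel. ;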
Example #3
    def __init__(
        self,
        *,
        manifest_filepath: str,
        labels: List[str],
        featurizer,
        min_duration: Optional[float] = 0.1,
        max_duration: Optional[float] = None,
        trim: bool = False,
        time_length: Optional[float] = 8,
        shift_length: Optional[float] = 1,
        normalize_audio: bool = False,
        is_regression_task: bool = False,
    ):
        self.time_length = time_length
        self.shift_length = shift_length
        self.normalize_audio = normalize_audio

        logging.debug("Time length considered for collate func is {}".format(self.time_length))
        logging.debug("Shift length considered for collate func is {}".format(self.shift_length))

        super().__init__(
            manifest_filepath=manifest_filepath,
            labels=labels,
            featurizer=featurizer,
            min_duration=min_duration,
            max_duration=max_duration,
            trim=trim,
            is_regression_task=is_regression_task,
        )
Example #4
def get_dataset_as_dict(file_path_patterns) -> dict:
    """Read the DSTC8/SGD json dialogue data as dictionary with dialog ID as keys.
    Args:
        file_path_patterns: list of file paths, or a glob pattern matching the files
    Returns:
        dataset_dict: dataset dictionary with dialog ID as keys
    """
    dataset_dict = {}
    if isinstance(file_path_patterns, list):
        list_fp = file_path_patterns
    else:
        list_fp = sorted(glob.glob(file_path_patterns))
    for fp in list_fp:
        if PER_FRAME_OUTPUT_FILENAME in fp:
            continue
        logging.debug("Loading file: %s", fp)
        with open(fp, encoding="UTF-8") as f:
            data = json.load(f)
            if isinstance(data, list):
                for dial in data:
                    dataset_dict[dial["dialogue_id"]] = dial
            elif isinstance(data, dict):
                dataset_dict.update(data)
    return dataset_dict
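A quick usage sketch (paths are illustrative); the function accepts either an explicit list of files or a glob pattern:

train_dialogs = get_dataset_as_dict(["train/dialogues_001.json", "train/dialogues_002.json"])
dev_dialogs = get_dataset_as_dict("dev/dialogues_*.json")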
Example #5
    def __init__(
        self,
        *,
        manifest_filepath: str,
        labels: List[str],
        featurizer,
        min_duration: Optional[float] = 0.1,
        max_duration: Optional[float] = None,
        trim: bool = False,
        load_audio: bool = True,
    ):
        super().__init__()
        self.collection = collections.ASRSpeechLabel(
            manifests_files=manifest_filepath.split(','),
            min_duration=min_duration,
            max_duration=max_duration,
        )

        self.featurizer = featurizer
        self.trim = trim
        self.load_audio = load_audio

        self.labels = labels if labels else self.collection.uniq_labels
        self.num_classes = len(self.labels)

        self.label2id, self.id2label = {}, {}
        for label_id, label in enumerate(self.labels):
            self.label2id[label] = label_id
            self.id2label[label_id] = label

        for idx in range(len(self.labels[:5])):
            logging.debug(" label id {} and its mapped label {}".format(
                idx, self.id2label[idx]))
Example #6
    def __init__(
        self,
        always_save_nemo=False,
        save_nemo_on_train_end=True,
        save_best_model=False,
        postfix=".nemo",
        n_resume=False,
        model_parallel_size=None,
        **kwargs,
    ):
        # Parse and store "extended" parameters: save_best model and postfix.
        self.always_save_nemo = always_save_nemo
        self.save_nemo_on_train_end = save_nemo_on_train_end
        self.save_best_model = save_best_model
        if self.save_best_model and not self.save_nemo_on_train_end:
            logging.warning((
                "Found save_best_model is True and save_nemo_on_train_end is False. "
                "Set save_nemo_on_train_end to True to automatically save the best model."
            ))
        self.postfix = postfix
        self.previous_best_path = ""
        self.model_parallel_size = model_parallel_size

        # `prefix` is deprecated
        if 'prefix' in kwargs:
            self.prefix = kwargs.pop('prefix')
        else:
            self.prefix = ""

        # Call the parent class constructor with the remaining kwargs.
        super().__init__(**kwargs)

        if self.save_top_k != -1 and n_resume:
            logging.debug("Checking previous runs")
            self.nemo_topk_check_previous_run()
Example #7
    def unfreeze(self, module_names: Optional[List[str]] = None):
        """
        Unfreezes weights of the trainable modules in a graph.

        Args:
            module_names: List of modules to be unfrozen (Optional). If not passed, all modules will be unfrozen.
        Raises:
            KeyError: If a module name is not recognized.
        """
        # Work on all modules.
        if module_names is None:
            module_names = self._modules.keys()

        # Iterate through modules one by one.
        for name in module_names:
            if name not in self._modules.keys():
                raise KeyError(
                    "Module `{}` not present in the `{}` graph".format(
                        name, self.name))
            # Check module type.
            module = self._modules[name]
            if module.type == ModuleType.trainable:
                # Unfreeze weights of the module.
                module.unfreeze()
            else:
                logging.debug(
                    "Module `{}` is not trainable so cannot be unfrozen".
                    format(name))
Example #8
    def __init__(self, schema_json_paths: Union[str, List[str]]):
        """
        schema_json_paths: list of paths to .json schema files, or a single str with the path to the json file.
        """
        # Load the schema from the json file.
        if isinstance(schema_json_paths, str):
            with open(schema_json_paths, "r") as f:
                all_schemas = json.load(f)
        else:
            # load multiple schemas from the list of the json files
            all_schemas = []
            completed_services = []
            for schema_json_path in schema_json_paths:
                with open(schema_json_path, "r") as f:
                    schemas = json.load(f)
                    logging.debug("Num of services in %s: %s", schema_json_path, len(schemas))

                for service in schemas:
                    if service['service_name'] not in completed_services:
                        completed_services.append(service['service_name'])
                        all_schemas.append(service)

        self._services = sorted(schema["service_name"] for schema in all_schemas)
        self._services_vocab = {v: k for k, v in enumerate(self._services)}
        self._services_id_to_vocab = {v: k for k, v in self._services_vocab.items()}
        service_schemas = {}
        for schema in all_schemas:
            service = schema["service_name"]
            service_schemas[service] = ServiceSchema(schema, service_id=self.get_service_id(service))

        self._service_schemas = service_schemas
        self._schemas = all_schemas
        self._slots_relation_list = {}
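A minimal usage sketch, assuming the enclosing class is named Schemas and the paths are illustrative; services that appear in more than one file are kept only once, as handled above:

schemas = Schemas(["sgd/train/schema.json", "sgd/dev/schema.json"])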
Example #9
    def __init__(
        self, freq_masks=0, time_masks=0, freq_width=10, time_width=0.1, rng=None, mask_value=0.0,
    ):
        super().__init__()
        # Log that the Numba SpecAugment kernel will be used
        # when the input device is CUDA and lengths are provided
        logging.debug("Numba SpecAugment kernel is available")

        self.freq_masks = freq_masks
        self.time_masks = time_masks

        self.freq_width = freq_width
        self.time_width = time_width

        self.mask_value = mask_value

        # Unused
        self.rng = rng
        if self.rng is not None:
            logging.warning("`rng` was supplied to SpecAugmentNumba, but it is not used.")

        if isinstance(time_width, int):
            self.adaptive_temporal_width = False
        else:
            if time_width > 1.0 or time_width < 0.0:
                raise ValueError('If `time_width` is a float value, must be in range [0, 1]')

            self.adaptive_temporal_width = True
Example #10
    def __init__(
        self,
        *,
        audio_tar_filepaths: Union[str, List[str]],
        manifest_filepath: Union[str, List[str]],
        labels: List[str],
        featurizer,
        shuffle_n: int = 0,
        min_duration: Optional[float] = 0.1,
        max_duration: Optional[float] = None,
        trim: bool = False,
        shard_strategy: str = "scatter",
        global_rank: int = 0,
        world_size: int = 0,
        is_regression_task: bool = False,
    ):
        self.collection = collections.ASRSpeechLabel(
            manifests_files=manifest_filepath,
            min_duration=min_duration,
            max_duration=max_duration,
            index_by_file_id=True,  # Must set this so the manifest lines can be indexed by file ID
        )

        self.file_occurence = count_occurence(self.collection.mapping)

        self.featurizer = featurizer
        self.trim = trim

        self.labels = labels if labels else self.collection.uniq_labels
        self.num_classes = len(self.labels)

        self.label2id, self.id2label = {}, {}
        for label_id, label in enumerate(self.labels):
            self.label2id[label] = label_id
            self.id2label[label_id] = label

        for idx in range(len(self.labels[:5])):
            logging.debug(" label id {} and its mapped label {}".format(idx, self.id2label[idx]))

        audio_tar_filepaths = expand_audio_filepaths(
            audio_tar_filepaths=audio_tar_filepaths,
            shard_strategy=shard_strategy,
            world_size=world_size,
            global_rank=global_rank,
        )
        # Put together WebDataset
        self._dataset = wd.WebDataset(urls=audio_tar_filepaths, nodesplitter=None)

        if shuffle_n > 0:
            self._dataset = self._dataset.shuffle(shuffle_n)
        else:
            logging.info("WebDataset will not shuffle files within the tar files.")

        self._dataset = (
            self._dataset.rename(audio=VALID_FILE_FORMATS, key='__key__')
            .to_tuple('audio', 'key')
            .pipe(self._filter)
            .map(f=self._build_sample)
        )
Example #11
    def from_config_dict(cls, config: 'DictConfig'):
        """Instantiates object using DictConfig-based configuration"""
        # Resolve the config dict
        if _HAS_HYDRA:
            if isinstance(config, DictConfig):
                config = OmegaConf.to_container(config, resolve=True)
                config = OmegaConf.create(config)
                OmegaConf.set_struct(config, True)

            config = maybe_update_config_version(config)

        # Hydra 0.x API
        if ('cls' in config
                or 'target' in config) and 'params' in config and _HAS_HYDRA:
            # regular hydra-based instantiation
            instance = hydra.utils.instantiate(config=config)
        # Hydra 1.x API
        elif '_target_' in config and _HAS_HYDRA:
            # regular hydra-based instantiation
            instance = hydra.utils.instantiate(config=config)
        else:
            instance = None
            imported_cls_tb = None
            # Attempt class path resolution from config `target` class (if it exists)
            if 'target' in config:
                target_cls = config["target"]  # No guarantee that this is an omegaconf class
                imported_cls = None
                try:
                    # try to import the target class
                    imported_cls = import_class_by_path(target_cls)
                except Exception:
                    imported_cls_tb = traceback.format_exc()

                # try instantiating model with target class
                if imported_cls is not None:
                    # if calling class (cls) is subclass of imported class,
                    # use subclass instead
                    if issubclass(cls, imported_cls):
                        imported_cls = cls

                    try:
                        instance = imported_cls(cfg=config)
                    except Exception:
                        imported_cls_tb = traceback.format_exc()
                        instance = None

            # target class resolution was unsuccessful, fall back to current `cls`
            if instance is None:
                if imported_cls_tb is not None:
                    logging.debug(
                        f"Model instantiation from target class {target_cls} failed with following error.\n"
                        f"Falling back to `cls`.\n"
                        f"{imported_cls_tb}")
                instance = cls(cfg=config)

        if not hasattr(instance, '_cfg'):
            instance._cfg = config
        return instance
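A minimal usage sketch of the Hydra 1.x path above (the class path and field are made up, and MyEncoder is assumed to expose from_config_dict as a classmethod): a config containing `_target_` goes straight to hydra.utils.instantiate, while other configs fall back to `target` class-path resolution or `cls(cfg=config)`.

from omegaconf import OmegaConf

cfg = OmegaConf.create({"_target_": "my_pkg.modules.MyEncoder", "hidden_size": 128})
encoder = MyEncoder.from_config_dict(cfg)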
Example #12
    def nemo_topk_check_previous_run(self):
        try:
            self.best_k_models
            self.kth_best_model_path
            self.best_model_score
            self.best_model_path
        except AttributeError:
            raise AttributeError(
                "Lightning's ModelCheckpoint was updated. NeMoModelCheckpoint will need an update."
            )
        self.best_k_models = {}
        self.kth_best_model_path = ""
        self.best_model_score = None
        self.best_model_path = ""

        checkpoints = list(Path(self.dirpath).rglob("*.ckpt"))
        for checkpoint in checkpoints:
            if self.model_parallel_size is not None and self.model_parallel_size > 1:
                checkpoint = self._uninject_mp_rank(checkpoint)
            checkpoint = str(checkpoint)
            if checkpoint[-10:] == '-last.ckpt':
                continue
            index = checkpoint.find(self.monitor)
            if index != -1:
                index += len(self.monitor) + 1  # skip past the monitor name and the '='
                match = re.search('[A-Za-z]', checkpoint[index:])
                if match:
                    value = checkpoint[index:index + match.start() -
                                       1]  # -1 due to separator hyphen
                    self.best_k_models[checkpoint] = float(value)
        if len(self.best_k_models) < 1:
            return  # No saved checkpoints yet

        _reverse = False if self.mode == "min" else True

        best_k_models = sorted(self.best_k_models,
                               key=self.best_k_models.get,
                               reverse=_reverse)

        ### This section should be ok as rank zero will delete all excess checkpoints, since all other ranks are
        ### instantiated after rank zero. models_to_delete should be 0 for all other ranks.
        if self.model_parallel_size is not None:
            models_to_delete = len(
                best_k_models) - self.model_parallel_size * self.save_top_k
        else:
            models_to_delete = len(best_k_models) - self.save_top_k
        logging.debug(f'Number of models to delete: {models_to_delete}')
        for _ in range(models_to_delete):
            model = best_k_models.pop(-1)
            self.best_k_models.pop(model)
            self._del_model_without_trainer(model)
            logging.debug(f"Removed checkpoint: {model}")

        self.kth_best_model_path = best_k_models[-1]
        self.best_model_path = best_k_models[0]
        self.best_model_score = self.best_k_models[self.best_model_path]
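A small illustration of the monitor-value parsing above on a hypothetical checkpoint name, with monitor set to 'val_loss':

import re

checkpoint = "megatron--val_loss=0.57-epoch=3.ckpt"
monitor = "val_loss"
index = checkpoint.find(monitor) + len(monitor) + 1    # position just after 'val_loss='
match = re.search('[A-Za-z]', checkpoint[index:])      # first letter after the numeric value ('e' of 'epoch')
value = float(checkpoint[index:index + match.start() - 1])  # 0.57; the -1 drops the '-' separator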
Example #13
    def get_features(all_sents,
                     tokenizer,
                     max_seq_length,
                     labels=None,
                     verbose=True):
        """Encode a list of sentences into a list of tuples of (input_ids, segment_ids, input_mask, label)."""
        features = []
        sent_lengths = []
        too_long_count = 0
        for sent_id, sent in enumerate(all_sents):
            if sent_id % 1000 == 0:
                logging.debug(f"Encoding sentence {sent_id}/{len(all_sents)}")
            sent_subtokens = [tokenizer.cls_token]
            for word in sent:
                word_tokens = tokenizer.text_to_tokens(word)
                sent_subtokens.extend(word_tokens)

            if max_seq_length > 0 and len(sent_subtokens) + 1 > max_seq_length:
                sent_subtokens = sent_subtokens[: max_seq_length - 1]  # leave room for the SEP token appended below
                too_long_count += 1

            sent_subtokens.append(tokenizer.sep_token)
            sent_lengths.append(len(sent_subtokens))

            input_ids = [tokenizer.tokens_to_ids(t) for t in sent_subtokens]

            # The mask has 1 for real tokens and 0 for padding tokens.
            # Only real tokens are attended to.
            input_mask = [1] * len(input_ids)
            segment_ids = [0] * len(input_ids)

            if verbose and sent_id < 2:
                logging.info("*** Example ***")
                logging.info(f"example {sent_id}: {sent}")
                logging.info("subtokens: %s" % " ".join(sent_subtokens))
                logging.info("input_ids: %s" % list2str(input_ids))
                logging.info("segment_ids: %s" % list2str(segment_ids))
                logging.info("input_mask: %s" % list2str(input_mask))
                logging.info("label: %s" %
                             labels[sent_id] if labels else "**Not Provided**")

            label = labels[sent_id] if labels else -1
            features.append([
                np.asarray(input_ids),
                np.asarray(segment_ids),
                np.asarray(input_mask), label
            ])

        if max_seq_length > -1 and too_long_count > 0:
            logging.warning(
                f'Found {too_long_count} out of {len(all_sents)} sentences with more than {max_seq_length} subtokens. '
                f'Truncated long sentences from the end.')
        if verbose:
            get_stats(sent_lengths)
        return features
Example #14
def write_predictions_to_file(
    predictions: List[dict],
    input_json_files: List[str],
    output_dir: str,
    schemas: object,
    state_tracker: str,
    eval_debug: bool,
    in_domain_services: set,
):
    """Save predicted dialogues as json files.

    Args:
        predictions: An iterator containing model predictions. This is the output of
            the predict method in the estimator.
        input_json_files: A list of json paths containing the dialogues to run
            inference on.
        output_dir: The directory where output json files will be created.
        schemas: Schemas for all services in the DST dataset
        state_tracker: state tracker option
        eval_debug: output evaluation debugging information
        in_domain_services: in domain services
    """
    logging.info(f"Writing predictions to {output_dir} started.")

    # Index all predictions.
    all_predictions = defaultdict(
        lambda: defaultdict(lambda: defaultdict(dict)))
    for idx, prediction in enumerate(predictions):
        eval_dataset, dialog_id, turn_id, service_name, model_task, slot_intent_id, value_id = prediction[
            'example_id'].split('-')
        all_predictions[(dialog_id, turn_id, service_name)][int(model_task)][
            int(slot_intent_id)][int(value_id)] = prediction
    logging.info(
        f'Predictions for {idx + 1} examples in {eval_dataset} dataset are getting processed.'
    )

    # Read each input file and write its predictions.
    for input_file_path in input_json_files:
        with open(input_file_path, encoding="UTF-8") as f:
            dialogs = json.load(f)
            logging.debug(f'{input_file_path} file is loaded')
            pred_dialogs = []
            for d in dialogs:
                pred_dialog = get_predicted_dialog(d, all_predictions, schemas,
                                                   state_tracker)
                pred_dialogs.append(pred_dialog)
        input_file_name = os.path.basename(input_file_path)
        output_file_path = os.path.join(output_dir, input_file_name)
        with open(output_file_path, "w", encoding="UTF-8") as f:
            json.dump(pred_dialogs,
                      f,
                      indent=2,
                      separators=(",", ": "),
                      sort_keys=True)
Example #15
File: common.py  Project: limberc/NeMo
    def from_config_dict(cls, config: DictConfig):
        """Instantiates object using DictConfig-based configuration"""
        # Resolve the config dict
        if isinstance(config, DictConfig):
            config = OmegaConf.to_container(config, resolve=True)
            config = OmegaConf.create(config)
            OmegaConf.set_struct(config, True)

        config = maybe_update_config_version(config)

        # Hydra 0.x API
        if ('cls' in config or 'target' in config) and 'params' in config:
            # regular hydra-based instantiation
            instance = hydra.utils.instantiate(config=config)
        # Hydra 1.x API
        elif '_target_' in config:
            # regular hydra-based instantiation
            instance = hydra.utils.instantiate(config=config)
        else:
            instance = None

            # Attempt class path resolution from config `target` class (if it exists)
            if 'target' in config:
                target_cls = config.target
                imported_cls = None
                try:
                    # try to import the target class
                    imported_cls = import_class_by_path(target_cls)
                except (ImportError, ModuleNotFoundError):
                    logging.debug(f'Target class `{target_cls}` could not be imported, falling back to original cls')

                # try instantiating model with target class
                if imported_cls is not None:
                    try:
                        instance = imported_cls(cfg=config)
                    except Exception:
                        imported_cls_tb = traceback.format_exc()
                        logging.debug(
                            f"Model instantiation from target class failed with following error.\n"
                            f"Falling back to `cls`.\n"
                            f"{imported_cls_tb}"
                        )
                        instance = None

            # target class resolution was unsuccessful, fall back to current `cls`
            if instance is None:
                instance = cls(cfg=config)

        if not hasattr(instance, '_cfg'):
            instance._cfg = config
        return instance
Example #16
    def save_to(self, filename: str, module_names: Optional[List[str]] = None):
        """
        Saves the state of trainable modules in the graph to a checkpoint file.

        Args:
            filename (string): Name of the file where the checkpoint will be saved.
            module_names: List of modules to be saved (Optional). If not passed, all modules will be saved.
        Raises:
            KeyError: If a module name is not recognized.
        """
        # Work on all modules.
        if module_names is None:
            module_names = self._modules.keys()

        # Prepare the "graph checkpoint".
        chkpt = {
            "header": {
                "nemo_core_version": nemo_version,
                "name": self.name
            },
            "modules": {}
        }

        log_str = ''
        # Iterate through the modules one by one.
        for name in module_names:
            if name not in self._modules.keys():
                raise KeyError(
                    "Module `{}` not present in the `{}` graph".format(
                        name, self.name))
            # Check module type.
            module = self._modules[name]
            if module.type == ModuleType.trainable:
                # Get module state_dict().
                chkpt["modules"][name] = get_state_dict(module)
                log_str += "  * Module '{}' ({}) params saved \n".format(
                    module.name,
                    type(module).__name__)
            else:
                logging.debug(
                    "Module `{}` is not trainable so cannot be saved".format(
                        name))

        # Save checkpoint.
        save(chkpt, filename)
        log_str = "Saved  the '{}' graph to a checkpoint `{}`:\n".format(
            self.name, filename) + log_str
        logging.info(log_str)
Example #17
def generate_vad_frame_pred(vad_model, window_length_in_sec,
                            shift_length_in_sec, manifest_vad_input, out_dir):
    """
    Generate VAD frame level prediction and write to out_dir
    """
    time_unit = int(window_length_in_sec / shift_length_in_sec)
    trunc = int(time_unit / 2)
    trunc_l = time_unit - trunc
    all_len = 0

    data = []
    for line in open(manifest_vad_input, 'r', encoding='utf-8'):
        file = json.loads(line)['audio_filepath'].split("/")[-1]
        data.append(file.split(".wav")[0])
    logging.info(f"Inference on {len(data)} audio files/json lines!")

    status = get_vad_stream_status(data)
    for i, test_batch in enumerate(vad_model.test_dataloader()):
        test_batch = [x.to(vad_model.device) for x in test_batch]
        with autocast():
            log_probs = vad_model(input_signal=test_batch[0],
                                  input_signal_length=test_batch[1])
            probs = torch.softmax(log_probs, dim=-1)
            pred = probs[:, 1]

            if status[i] == 'start':
                to_save = pred[:-trunc]
            elif status[i] == 'next':
                to_save = pred[trunc:-trunc_l]
            elif status[i] == 'end':
                to_save = pred[trunc_l:]
            else:
                to_save = pred

            all_len += len(to_save)
            outpath = os.path.join(out_dir, data[i] + ".frame")
            with open(outpath, "a", encoding='utf-8') as fout:
                for f in range(len(to_save)):
                    fout.write('{0:0.4f}\n'.format(to_save[f]))

        del test_batch
        if status[i] == 'end' or status[i] == 'single':
            logging.debug(
                f"Overall length of prediction of {data[i]} is {all_len}!")
            all_len = 0
    return out_dir
Example #18
    def setup_loss(self, class_balancing: str = None):
        """Setup loss
           Setup or update loss.

        Args:
            class_balancing: whether to use class weights during training
        """
        if class_balancing not in ['weighted_loss', None]:
            raise ValueError(f'Class balancing {class_balancing} is not supported. Choose from: [null, weighted_loss]')
        if class_balancing == 'weighted_loss' and self.class_weights:
            # you may need to increase the number of epochs for convergence when using weighted_loss
            loss = CrossEntropyLoss(logits_ndim=3, weight=self.class_weights)
            logging.debug(f'Using {class_balancing} class balancing.')
        else:
            loss = CrossEntropyLoss(logits_ndim=3)
            logging.debug('Using CrossEntropyLoss without class balancing.')
        return loss
Example #19
    def _perform_speech_activity_detection(self):
        """
        Checks the type of speech activity detection specified in the config. Choices are NeMo VAD,
        an external VAD manifest, and oracle VAD (which generates speech activity labels from the provided RTTM files).
        """
        if self.has_vad_model:
            self._dont_auto_split = False
            self._split_duration = 50
            manifest_vad_input = self._diarizer_params.manifest_filepath

            if not self._dont_auto_split:
                logging.info(
                    "Split long audio file to avoid CUDA memory issue")
                logging.debug(
                    "Try smaller split_duration if you still have CUDA memory issue"
                )
                config = {
                    'manifest_filepath': manifest_vad_input,
                    'time_length': self._vad_window_length_in_sec,
                    'split_duration': self._split_duration,
                    'num_workers': self._cfg.num_workers,
                }
                manifest_vad_input = prepare_manifest(config)
            else:
                logging.warning(
                    "If you encounter CUDA memory issue, try splitting manifest entry by split_duration to avoid it."
                )

            self._setup_vad_test_data(manifest_vad_input)
            self._run_vad(manifest_vad_input)

        elif self._diarizer_params.vad.external_vad_manifest is not None:
            self._speaker_manifest_path = self._diarizer_params.vad.external_vad_manifest
        elif self._diarizer_params.oracle_vad:
            self._speaker_manifest_path = os.path.join(
                self._speaker_dir, 'oracle_vad_manifest.json')
            self._speaker_manifest_path = write_rttm2manifest(
                self.AUDIO_RTTM_MAP, self._speaker_manifest_path)
        else:
            raise ValueError(
                "Only one of diarizer.oracle_vad, vad.model_path or vad.external_vad_manifest must be passed"
            )
Example #20
    def build_vocab_from_csv(self, data_csv_file, col="smiles"):
        """
        Learns vocabulary from a CSV file. Can be called multiple times to update vocabulary.
        """
        logging.debug(
            f"Building vocabulary from CSV col = {col} file = {data_csv_file}")

        # NOTE this has to be run on each CSV file
        if not os.path.exists(data_csv_file):
            raise ValueError(f"Data file: {data_csv_file} is missing")

        df = pd.read_csv(data_csv_file)

        vocab = self.vocab
        for d in df[col]:
            tokens = self.text_to_tokens(d)
            logging.debug(f"Text: {d}, Tokens: {tokens}")
            for token in tokens:
                if token not in vocab:
                    vocab[token] = len(vocab)

        sorted_vocab = sorted(vocab.items(), key=lambda k_v: k_v[1])
        logging.debug(f"Vocab: {sorted_vocab}")

        self.vocab = vocab
        self._update_cache()
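A short usage sketch (file names are illustrative), assuming tokenizer is an instance of the class this method belongs to; repeated calls keep extending the same vocabulary:

tokenizer.build_vocab_from_csv("train_molecules.csv", col="smiles")
tokenizer.build_vocab_from_csv("val_molecules.csv", col="smiles")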
Example #21
    def forward(self, gating_preds, point_outputs_pred, belief_state,
                user_uttr):
        """
        Processes the TRADE model output and updates the dialogue (belief) state with the model's predictions
        Args:
            user_uttr (str): user utterance
            belief_state (dict): dialogue belief state, contains slot-slot value pairs for all domains
            gating_preds (float): TRADE model gating predictions
            point_outputs_pred (float): TRADE model pointer predictions
        Returns:
            updated request_state (dict)
            updated belief_state (dict)
        """
        gate_outputs_max, point_outputs_max = self.get_trade_prediction(
            gating_preds, point_outputs_pred)
        trade_output = self.get_human_readable_output(gate_outputs_max,
                                                      point_outputs_max)[0]
        logging.debug('TRADE output: %s', trade_output)

        new_belief_state = self.reformat_belief_state(
            trade_output, copy.deepcopy(belief_state),
            self.data_desc.ontology_value_dict)
        # update request state based on the latest user utterance
        # extract current user output
        new_request_state = self.detect_requestable_slots(
            user_uttr.lower(), self.data_desc.det_dict)
        logging.debug('Belief State after TRADE: %s', new_belief_state)
        logging.debug('Request State after TRADE: %s', new_request_state)
        return new_belief_state, new_request_state
Example #22
    def load_tokenizer(self, base_fname):
        """
        Loads tokenizer's regex (base_fname.model) and vocab (base_fname.vocab) files
        """
        if base_fname.endswith(".model"):
            base_fname = os.path.splitext(base_fname)[0]

        if base_fname:
            self.base_fname = base_fname

        if not self.base_fname:
            raise ValueError(f"base_fname must be specified")

        vocab_file = self.base_fname + '.vocab'
        regex_file = self.base_fname + '.model'

        # load vocab file
        # vocab_file: path to file with vocabulary which consists
        # of characters separated by \n (None/"" for empty vocab)

        logging.debug(f"Loading vocabulary from file = {vocab_file}")
        if os.path.exists(vocab_file):
            vocab = {}
            with open(vocab_file, "r") as f:
                for line in f:
                    line = line.strip()
                    if line:
                        vocab[line] = len(vocab)
            self.vocab = vocab
        else:
            raise RuntimeError(f"Missing vocab_file = {vocab_file}")

        # load regex from a file
        if os.path.exists(regex_file):
            logging.debug(f"Loading regex from file = {regex_file}")
            with open(regex_file, encoding="utf-8") as f:
                self.regex = f.read().strip()
        else:
            raise RuntimeError(f"Missing regex_file = {regex_file}")

        return self
Example #23
def configure_checkpointing(trainer: 'pytorch_lightning.Trainer',
                            log_dir: Path, name: str, params: 'DictConfig'):
    """ Adds ModelCheckpoint to trainer. Raises CheckpointMisconfigurationError if trainer already has a ModelCheckpoint
    callback or if trainer.weights_save_path was passed to Trainer.
    """
    for callback in trainer.callbacks:
        if isinstance(callback, ModelCheckpoint):
            raise CheckpointMisconfigurationError(
                "The pytorch lightning trainer that was passed to exp_manager contained a ModelCheckpoint "
                "and create_checkpoint_callback was set to True. Please either set create_checkpoint_callback "
                "to False, or remove ModelCheckpoint from the lightning trainer"
            )
    if Path(trainer.weights_save_path) != Path.cwd():
        raise CheckpointMisconfigurationError(
            "The pytorch lightning was passed weights_save_path. This variable is ignored by exp_manager"
        )

    # Create the callback and attach it to trainer
    if "filepath" in params:
        if params.filepath is not None:
            logging.warning(
                "filepath is deprecated. Please switch to dirpath and filename instead"
            )
            if params.dirpath is None:
                params.dirpath = Path(params.filepath).parent
            if params.filename is None:
                params.filename = Path(params.filepath).name
        with open_dict(params):
            del params["filepath"]
    if params.dirpath is None:
        params.dirpath = Path(log_dir / 'checkpoints')
    if params.filename is None:
        params.filename = f'{name}--{{{params.monitor}:.2f}}-{{epoch}}'
    if params.prefix is None:
        params.prefix = name
    NeMoModelCheckpoint.CHECKPOINT_NAME_LAST = params.filename + '-last'

    logging.debug(params.dirpath)
    logging.debug(params.filename)
    logging.debug(params.prefix)

    if "val" in params.monitor:
        if (trainer.max_epochs is not None and trainer.max_epochs != -1
                and trainer.max_epochs < trainer.check_val_every_n_epoch):
            logging.error(
                "The checkpoint callback was told to monitor a validation value but trainer.max_epochs("
                f"{trainer.max_epochs}) was less than trainer.check_val_every_n_epoch({trainer.check_val_every_n_epoch}"
                f"). It is very likely this run will fail with ModelCheckpoint(monitor='{params.monitor}') not found "
                "in the returned metrics. Please ensure that validation is run within trainer.max_epochs."
            )
        elif trainer.max_steps is not None:
            logging.warning(
                "The checkpoint callback was told to monitor a validation value and trainer's max_steps was set to "
                f"{trainer.max_steps}. Please ensure that max_steps will run for at least "
                f"{trainer.check_val_every_n_epoch} epochs to ensure that checkpointing will not error out."
            )

    checkpoint_callback = NeMoModelCheckpoint(**params)
    checkpoint_callback.last_model_path = trainer.resume_from_checkpoint or ""
    trainer.callbacks.append(checkpoint_callback)
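The params this function reads, sketched as a minimal DictConfig (values are illustrative); dirpath, filename and prefix may be None and are then filled in above before the callback is created:

from omegaconf import OmegaConf

params = OmegaConf.create({
    "monitor": "val_loss",
    "mode": "min",
    "save_top_k": 3,
    "dirpath": None,
    "filename": None,
    "prefix": None,
})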
Example #24
    def add_noncategorical_slots(self, state_update: dict,
                                 system_span_boundaries: dict,
                                 user_span_boundaries: dict):
        """Add features for non-categorical slots.
        Args:
            state_update: slot value pairs of state update
            system_span_boundaries: span boundaries of slot values in the system utterance
            user_span_boundaries: span boundaries of slot values in the user utterance
        """

        noncategorical_slots = self.service_schema.non_categorical_slots
        slot = noncategorical_slots[self.noncategorical_slot_id]

        values = state_update.get(slot, [])
        if not values:
            self.noncategorical_slot_status = STATUS_OFF
        elif values[0] == STR_DONTCARE:
            self.noncategorical_slot_status = STATUS_DONTCARE
        else:
            self.noncategorical_slot_status = STATUS_ACTIVE
            # Add indices of the start and end tokens for the first encountered
            # value. Spans in user utterance are prioritized over the system
            # utterance. If a span is not found, the slot value is ignored.
            if slot in user_span_boundaries:
                start, end = user_span_boundaries[slot]
            elif slot in system_span_boundaries:
                start, end = system_span_boundaries[slot]
            else:
                # A span may not be found because the value was cropped out or because
                # the value was mentioned earlier in the dialogue. Since this model
                # only makes use of the last two utterances to predict state updates,
                # it will fail in such cases.
                logging.debug(
                    f'"Slot values {str(values)} not found in user or system utterance in example with id - {self.example_id}.'
                )
                start = 0
                end = 0
            self.noncategorical_slot_value_start = start
            self.noncategorical_slot_value_end = end
Example #25
    def save_tokenizer(self, base_fname=None):
        """
        Saves tokenizer's regex (base_fname.model) and vocab (base_fname.vocab) files
        """
        if base_fname and base_fname.endswith(".model"):
            base_fname = os.path.splitext(base_fname)[0]

        if base_fname:
            self.base_fname = base_fname

        if not self.base_fname:
            raise ValueError(f"base_fname must be specified")

        vocab_file = self.base_fname + '.vocab'
        regex_file = self.base_fname + '.model'

        logging.debug(f"Saving vocabulary to file = {vocab_file}")
        with open(vocab_file, 'w') as fp:
            # write one token per line, in vocabulary-id order
            for token, _ in sorted(self.vocab.items(), key=lambda k_v: k_v[1]):
                fp.write(f"{token}\n")

        logging.debug(f"Saving regex to file = {regex_file}")
        with open(regex_file, 'w') as fp:
            fp.write(self.regex)
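A round-trip sketch tying this together with load_tokenizer from Example #22 (the base name is illustrative, and tokenizer is assumed to be an instance of the same class); both methods also accept a name ending in ".model" and strip the extension:

tokenizer.save_tokenizer("my_tokenizer")        # writes my_tokenizer.vocab and my_tokenizer.model
tokenizer.load_tokenizer("my_tokenizer.model")  # the ".model" suffix is stripped before loading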
Example #26
    def __init__(
        self,
        *,
        manifest_filepath: str,
        labels: List[str],
        feature_loader,
        is_speaker_emb: bool = False,
    ):
        super().__init__()
        self.collection = collections.ASRFeatureSequenceLabel(
            manifests_files=manifest_filepath.split(','), )

        self.feature_loader = feature_loader
        self.labels = labels if labels else self.collection.uniq_labels
        self.is_speaker_emb = is_speaker_emb

        self.label2id, self.id2label = {}, {}
        for label_id, label in enumerate(self.labels):
            self.label2id[label] = label_id
            self.id2label[label_id] = label

        for idx in range(len(self.labels[:5])):
            logging.debug(" label id {} and its mapped label {}".format(
                idx, self.id2label[idx]))
Example #27
def main(cfg: DictConfig) -> None:
    logging.debug(f'Config Params: {OmegaConf.to_yaml(cfg)}')

    if cfg.pretrained_model is None:
        raise ValueError("A pre-trained model should be provided.")
    _, model = instantiate_model_and_trainer(cfg, ITN_MODEL, False)

    text_file = cfg.inference.from_file
    logging.info(f"Running inference on {text_file}...")
    if not os.path.exists(text_file):
        raise ValueError(f"{text_file} not found.")

    with open(text_file, "r", encoding="utf-8") as f:
        lines = f.readlines()

    batch_size = cfg.inference.get("batch_size", 8)

    batch, all_preds = [], []
    for i, line in enumerate(lines):
        s = spoken_preprocessing(line)  # this is the same input transformation as in corpus preparation
        batch.append(s.strip())
        if len(batch) == batch_size or i == len(lines) - 1:
            outputs = model._infer(batch)
            for x in outputs:
                all_preds.append(x)
            batch = []
    if len(all_preds) != len(lines):
        raise ValueError(
            "number of input lines and predictions is different: predictions="
            + str(len(all_preds)) + "; lines=" + str(len(lines)))
    out_file = cfg.inference.out_file
    with open(f"{out_file}", "w", encoding="utf-8") as f_out:
        f_out.write("\n".join(all_preds))
    logging.info(f"Predictions saved to {out_file}.")
Example #28
    def build_vocab_from_text(self, data_text_file):
        """
        Learns vocabulary from a text file. Can be called multiple times to update vocabulary.
        """
        logging.debug(f"Building vocabulary from TEXT file = {data_text_file}")

        # NOTE this has to be run on each text file
        if not os.path.exists(data_text_file):
            raise ValueError(f"Data file: {data_text_file} is missing")

        vocab = self.vocab
        for d in open(data_text_file, encoding="utf-8").readlines():
            d = d.rstrip()
            tokens = self.text_to_tokens(d)
            logging.debug(f"Text: {d}, Tokens: {d}")
            for token in tokens:
                if token not in vocab:
                    vocab[token] = len(vocab)

        sorted_vocab = sorted(vocab.items(), key=lambda k_v: k_v[1])
        logging.debug(f"Vocab: {sorted_vocab}")

        self.vocab = vocab
        self._update_cache()
Example #29
    def get_features(
        queries,
        max_seq_length,
        tokenizer,
        pad_label=128,
        word_level_slots=None,
        ignore_extra_tokens=False,
        ignore_start_end=False,
    ):
        """
        Convert queries (utterance, intent label and slot labels) to BERT input format 
        """

        all_subtokens = []
        all_loss_mask = []
        all_subtokens_mask = []
        all_segment_ids = []
        all_input_ids = []
        all_input_mask = []
        sent_lengths = []
        all_slots = []

        with_label = word_level_slots is not None

        for i, query in enumerate(queries):
            words = query.strip().split()
            subtokens = [tokenizer.cls_token]
            loss_mask = [1 - ignore_start_end]
            subtokens_mask = [0]
            if with_label:
                slots = [pad_label]

            for j, word in enumerate(words):
                word_tokens = tokenizer.text_to_tokens(word)

                # to handle emojis that could be neglected during tokenization
                if len(word.strip()) > 0 and len(word_tokens) == 0:
                    word_tokens = [tokenizer.ids_to_tokens(tokenizer.unk_id)]

                subtokens.extend(word_tokens)
                # mask all sub-word tokens except the first token in a word
                # use the label for the first sub-word token as the label for the entire word to eliminate need for disambiguation
                loss_mask.append(1)
                loss_mask.extend([int(not ignore_extra_tokens)] *
                                 (len(word_tokens) - 1))

                subtokens_mask.append(1)
                subtokens_mask.extend([0] * (len(word_tokens) - 1))

                if with_label:
                    slots.extend([word_level_slots[i][j]] * len(word_tokens))

            subtokens.append(tokenizer.sep_token)
            loss_mask.append(1 - ignore_start_end)
            subtokens_mask.append(0)
            sent_lengths.append(len(subtokens))
            all_subtokens.append(subtokens)
            all_loss_mask.append(loss_mask)
            all_subtokens_mask.append(subtokens_mask)
            all_input_mask.append([1] * len(subtokens))
            if with_label:
                slots.append(pad_label)
                all_slots.append(slots)
        max_seq_length_data = max(sent_lengths)
        max_seq_length = min(
            max_seq_length,
            max_seq_length_data) if max_seq_length > 0 else max_seq_length_data
        logging.info(f'Setting max length to: {max_seq_length}')
        get_stats(sent_lengths)

        # truncate and pad samples
        (
            all_slots,
            all_subtokens,
            all_input_mask,
            all_loss_mask,
            all_subtokens_mask,
            all_input_ids,
            all_segment_ids,
        ) = DialogueBERTDataset.truncate_and_pad(
            max_seq_length,
            ignore_start_end,
            with_label,
            pad_label,
            tokenizer,
            all_slots,
            all_subtokens,
            all_input_mask,
            all_loss_mask,
            all_subtokens_mask,
            all_input_ids,
            all_segment_ids,
        )

        # log examples for debugging
        logging.debug("*** Some Examples of Processed Data ***")
        for i in range(min(len(all_input_ids), 5)):
            logging.debug("i: %s" % (i))
            logging.debug("subtokens: %s" %
                          " ".join(list(map(str, all_subtokens[i]))))
            logging.debug("loss_mask: %s" %
                          " ".join(list(map(str, all_loss_mask[i]))))
            logging.debug("input_mask: %s" %
                          " ".join(list(map(str, all_input_mask[i]))))
            logging.debug("subtokens_mask: %s" %
                          " ".join(list(map(str, all_subtokens_mask[i]))))
            if with_label:
                logging.debug("slots_label: %s" %
                              " ".join(list(map(str, all_slots[i]))))

        return (all_input_ids, all_segment_ids, all_input_mask, all_loss_mask,
                all_subtokens_mask, all_slots)
Example #30
    def __init__(
        self,
        *,
        audio_tar_filepaths: Union[str, List[str]],
        manifest_filepath: str,
        labels: List[str],
        featurizer,
        shuffle_n: int = 0,
        min_duration: Optional[float] = 0.1,
        max_duration: Optional[float] = None,
        trim: bool = False,
        load_audio: bool = True,
        shard_strategy: str = "scatter",
        global_rank: int = 0,
        world_size: int = 0,
    ):
        self.collection = collections.ASRSpeechLabel(
            manifests_files=manifest_filepath.split(','),
            min_duration=min_duration,
            max_duration=max_duration,
            index_by_file_id=True,  # Must set this so the manifest lines can be indexed by file ID
        )

        self.file_occurence = count_occurence(self.collection.mapping)

        self.featurizer = featurizer
        self.trim = trim
        self.load_audio = load_audio

        self.labels = labels if labels else self.collection.uniq_labels
        self.num_classes = len(self.labels)

        self.label2id, self.id2label = {}, {}
        for label_id, label in enumerate(self.labels):
            self.label2id[label] = label_id
            self.id2label[label_id] = label

        for idx in range(len(self.labels[:5])):
            logging.debug(" label id {} and its mapped label {}".format(
                idx, self.id2label[idx]))

        valid_shard_strategies = ['scatter', 'replicate']
        if shard_strategy not in valid_shard_strategies:
            raise ValueError(
                f"`shard_strategy` must be one of {valid_shard_strategies}")

        if isinstance(audio_tar_filepaths, str):
            # Replace '(', '[', '<' and '_OP_' with '{'
            brace_keys_open = ['(', '[', '<', '_OP_']
            for bkey in brace_keys_open:
                if bkey in audio_tar_filepaths:
                    audio_tar_filepaths = audio_tar_filepaths.replace(
                        bkey, "{")

            # Replace ')', ']', '>' and '_CL_' with '}'
            brace_keys_close = [')', ']', '>', '_CL_']
            for bkey in brace_keys_close:
                if bkey in audio_tar_filepaths:
                    audio_tar_filepaths = audio_tar_filepaths.replace(
                        bkey, "}")

        # Check for distributed and partition shards accordingly
        if world_size > 1:
            if isinstance(audio_tar_filepaths, str):
                # Brace expand
                audio_tar_filepaths = list(
                    braceexpand.braceexpand(audio_tar_filepaths))

            if shard_strategy == 'scatter':
                logging.info(
                    "All tarred dataset shards will be scattered evenly across all nodes."
                )

                if len(audio_tar_filepaths) % world_size != 0:
                    logging.warning(
                        f"Number of shards in tarred dataset ({len(audio_tar_filepaths)}) is not divisible "
                        f"by number of distributed workers ({world_size}).")

                begin_idx = (len(audio_tar_filepaths) //
                             world_size) * global_rank
                end_idx = begin_idx + (len(audio_tar_filepaths) // world_size)
                audio_tar_filepaths = audio_tar_filepaths[begin_idx:end_idx]
                logging.info(
                    "Partitioning tarred dataset: process (%d) taking shards [%d, %d)",
                    global_rank, begin_idx, end_idx)

            elif shard_strategy == 'replicate':
                logging.info(
                    "All tarred dataset shards will be replicated across all nodes."
                )

            else:
                raise ValueError(
                    f"Invalid shard strategy ! Allowed values are : {valid_shard_strategies}"
                )

        # Put together WebDataset
        self._dataset = wd.WebDataset(audio_tar_filepaths)

        if shuffle_n > 0:
            self._dataset = self._dataset.shuffle(shuffle_n)
        else:
            logging.info(
                "WebDataset will not shuffle files within the tar files.")

        self._dataset = (
            self._dataset.rename(audio='wav', key='__key__')
            .to_tuple('audio', 'key')
            .pipe(self._filter)
            .map(f=self._build_sample)
        )
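A small illustration (shard names are hypothetical) of the brace-key replacement and expansion performed above before handing the paths to WebDataset:

import braceexpand

path = "audio__OP_0..3_CL_.tar".replace("_OP_", "{").replace("_CL_", "}")
print(list(braceexpand.braceexpand(path)))
# ['audio_0.tar', 'audio_1.tar', 'audio_2.tar', 'audio_3.tar']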