def bind(self, tensors_ref: List["NmTensor"], port_names: Optional[List[str]] = None):
    """
    Binds the "default" outputs.

    Args:
        tensors_ref: List of tensors to be added.
        port_names: List of port names (visible outside). If None: uses the internal tensor "output port names".
    """
    # Set names.
    if port_names is None:
        port_names = [tensor.name for tensor in tensors_ref]

    for name, tensor in zip(port_names, tensors_ref):
        # Check the presence of the port name in the "default" dictionary.
        if name in self._default_outputs.keys():
            # Name already taken - make it unique by combining the producer step number, producer name and port name.
            name = str(tensor.producer_step_number) + "_" + tensor.producer_name + "_" + tensor.name  # last = port name
            logging.debug(
                "Setting unique name of the default output port `{}` produced in step {} by `{}` to `{}`".format(
                    tensor.name, tensor.producer_step_number, tensor.producer_name, name
                )
            )
        # Store the output.
        self._default_outputs[name] = GraphOutput(tensor.ntype, tensor.producer_step_module_port)
def forward(self, user_uttr, dialog_history):
    """
    Returns dialogue utterances in the format accepted by the TRADE dialogue state tracking model.

    Args:
        dialog_history (list): dialogue history, list of system and user utterances
        user_uttr (str): user utterance
    Returns:
        dialog_ids (int): token ids for the whole dialogue history
        dialog_lens (int): length of the whole tokenized dialogue history
        dialog_history (list): updated dialogue history, list of system and user utterances
    """
    # TODO: why do we update the dialogue history here, when we only have the user utterance at this point?
    dialog_history.append(["user", user_uttr])
    logging.debug("Dialogue history: %s", dialog_history)

    context = ' ; '.join([item[1].strip().lower() for item in dialog_history]).strip() + ' ;'
    context_ids = self.data_desc.vocab.tokens2ids(context.split())
    dialog_ids = torch.tensor(context_ids).unsqueeze(0).to(self._device)
    dialog_lens = torch.tensor(len(context_ids)).unsqueeze(0).to(self._device)

    # logging.debug("!! dialog_ids: %s", dialog_ids)
    # logging.debug("!! dialog_lens: %s", dialog_lens)
    return dialog_ids, dialog_lens, dialog_history
def __init__(
    self,
    *,
    manifest_filepath: str,
    labels: List[str],
    featurizer,
    min_duration: Optional[float] = 0.1,
    max_duration: Optional[float] = None,
    trim: bool = False,
    time_length: Optional[float] = 8,
    shift_length: Optional[float] = 1,
    normalize_audio: bool = False,
    is_regression_task: bool = False,
):
    self.time_length = time_length
    self.shift_length = shift_length
    self.normalize_audio = normalize_audio

    logging.debug("Time length considered for collate func is {}".format(self.time_length))
    logging.debug("Shift length considered for collate func is {}".format(self.shift_length))

    super().__init__(
        manifest_filepath=manifest_filepath,
        labels=labels,
        featurizer=featurizer,
        min_duration=min_duration,
        max_duration=max_duration,
        trim=trim,
        is_regression_task=is_regression_task,
    )
def get_dataset_as_dict(file_path_patterns) -> dict:
    """Read the DSTC8/SGD json dialogue data as a dictionary keyed by dialogue ID.

    Args:
        file_path_patterns: list of file paths, or a glob pattern matching them
    Returns:
        dataset_dict: dataset dictionary keyed by dialogue ID
    """
    dataset_dict = {}
    if isinstance(file_path_patterns, list):
        list_fp = file_path_patterns
    else:
        list_fp = sorted(glob.glob(file_path_patterns))
    for fp in list_fp:
        if PER_FRAME_OUTPUT_FILENAME in fp:
            continue
        logging.debug("Loading file: %s", fp)
        with open(fp, encoding="UTF-8") as f:
            data = json.load(f)
            if isinstance(data, list):
                for dial in data:
                    dataset_dict[dial["dialogue_id"]] = dial
            elif isinstance(data, dict):
                dataset_dict.update(data)
    return dataset_dict
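# Usage sketch for get_dataset_as_dict (hedged): load every SGD dialogue file
# matched by a glob pattern and index the result by dialogue ID. The path
# below is a hypothetical placeholder, not a path from the dataset.
dataset = get_dataset_as_dict("sgd/train/dialogues_*.json")
logging.info("Loaded %d dialogues", len(dataset))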
def __init__(
    self,
    *,
    manifest_filepath: str,
    labels: List[str],
    featurizer,
    min_duration: Optional[float] = 0.1,
    max_duration: Optional[float] = None,
    trim: bool = False,
    load_audio: bool = True,
):
    super().__init__()
    self.collection = collections.ASRSpeechLabel(
        manifests_files=manifest_filepath.split(','),
        min_duration=min_duration,
        max_duration=max_duration,
    )

    self.featurizer = featurizer
    self.trim = trim
    self.load_audio = load_audio
    self.labels = labels if labels else self.collection.uniq_labels
    self.num_classes = len(self.labels)

    self.label2id, self.id2label = {}, {}
    for label_id, label in enumerate(self.labels):
        self.label2id[label] = label_id
        self.id2label[label_id] = label

    for idx in range(len(self.labels[:5])):
        logging.debug(" label id {} and its mapped label {}".format(idx, self.id2label[idx]))
def __init__(
    self,
    always_save_nemo=False,
    save_nemo_on_train_end=True,
    save_best_model=False,
    postfix=".nemo",
    n_resume=False,
    model_parallel_size=None,
    **kwargs,
):
    # Parse and store the "extended" parameters: save_best_model and postfix.
    self.always_save_nemo = always_save_nemo
    self.save_nemo_on_train_end = save_nemo_on_train_end
    self.save_best_model = save_best_model
    if self.save_best_model and not self.save_nemo_on_train_end:
        logging.warning(
            "Found save_best_model is True and save_nemo_on_train_end is False. "
            "Set save_nemo_on_train_end to True to automatically save the best model."
        )
    self.postfix = postfix
    self.previous_best_path = ""
    self.model_parallel_size = model_parallel_size

    # `prefix` is deprecated
    if 'prefix' in kwargs:
        self.prefix = kwargs.pop('prefix')
    else:
        self.prefix = ""

    # Call the parent class constructor with the remaining kwargs.
    super().__init__(**kwargs)

    if self.save_top_k != -1 and n_resume:
        logging.debug("Checking previous runs")
        self.nemo_topk_check_previous_run()
def unfreeze(self, module_names: Optional[List[str]] = None):
    """
    Unfreezes weights of the trainable modules in a graph.

    Args:
        module_names: List of modules to be unfrozen (Optional). If not passed, all modules will be unfrozen.
    Raises:
        KeyError: If a module name is not recognized.
    """
    # Work on all modules.
    if module_names is None:
        module_names = self._modules.keys()

    # Iterate through modules one by one.
    for name in module_names:
        if name not in self._modules.keys():
            raise KeyError("Module `{}` not present in the `{}` graph".format(name, self.name))
        # Check module type.
        module = self._modules[name]
        if module.type == ModuleType.trainable:
            # Unfreeze weights of the module.
            module.unfreeze()
        else:
            logging.debug("Module `{}` is not trainable, so it cannot be unfrozen".format(name))
def __init__(self, schema_json_paths: Union[str, List[str]]):
    """
    schema_json_paths: list of paths to .json schema files, or a single str with the path to one json file.
    """
    # Load the schema(s) from the json file(s).
    if isinstance(schema_json_paths, str):
        with open(schema_json_paths, "r") as f:
            all_schemas = json.load(f)
    else:
        # Load multiple schemas from the list of json files, keeping only the first
        # occurrence of each service.
        all_schemas = []
        completed_services = []
        for schema_json_path in schema_json_paths:
            with open(schema_json_path, "r") as f:
                schemas = json.load(f)
            logging.debug("Num of services in %s: %s", schema_json_path, len(schemas))

            for service in schemas:
                if service['service_name'] not in completed_services:
                    completed_services.append(service['service_name'])
                    all_schemas.append(service)

    self._services = sorted(schema["service_name"] for schema in all_schemas)
    self._services_vocab = {v: k for k, v in enumerate(self._services)}
    self._services_id_to_vocab = {v: k for k, v in self._services_vocab.items()}

    service_schemas = {}
    for schema in all_schemas:
        service = schema["service_name"]
        service_schemas[service] = ServiceSchema(schema, service_id=self.get_service_id(service))

    self._service_schemas = service_schemas
    self._schemas = all_schemas
    self._slots_relation_list = {}
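# Usage sketch (hedged): assuming this __init__ belongs to a Schemas-style
# wrapper class (the class name and file paths below are hypothetical).
# Duplicate service definitions across files are kept only once, and service
# ids follow the sorted order of service names.
schemas = Schemas(["sgd/train/schema.json", "sgd/dev/schema.json"])
first_service = schemas._services[0]
logging.debug("Service %s has id %s", first_service, schemas.get_service_id(first_service))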
def __init__(
    self,
    freq_masks=0,
    time_masks=0,
    freq_width=10,
    time_width=0.1,
    rng=None,
    mask_value=0.0,
):
    super().__init__()
    # The Numba SpecAugment kernel is used when the input device is CUDA
    # and lengths are provided.
    logging.debug("Numba SpecAugment kernel is available")

    self.freq_masks = freq_masks
    self.time_masks = time_masks
    self.freq_width = freq_width
    self.time_width = time_width
    self.mask_value = mask_value

    # Unused
    self.rng = rng
    if self.rng is not None:
        logging.warning("`rng` was supplied to SpecAugmentNumba, but it is not used.")

    if isinstance(time_width, int):
        self.adaptive_temporal_width = False
    else:
        if time_width > 1.0 or time_width < 0.0:
            raise ValueError('If `time_width` is a float value, it must be in range [0, 1]')
        self.adaptive_temporal_width = True
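# Hedged configuration sketch: an int `time_width` is a fixed mask width in
# time steps, while a float in [0, 1] enables adaptive masking as a fraction
# of the utterance length. The values below are illustrative, not defaults.
spec_augment = SpecAugmentNumba(freq_masks=2, time_masks=10, freq_width=27, time_width=0.05)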
def __init__(
    self,
    *,
    audio_tar_filepaths: Union[str, List[str]],
    manifest_filepath: Union[str, List[str]],
    labels: List[str],
    featurizer,
    shuffle_n: int = 0,
    min_duration: Optional[float] = 0.1,
    max_duration: Optional[float] = None,
    trim: bool = False,
    shard_strategy: str = "scatter",
    global_rank: int = 0,
    world_size: int = 0,
    is_regression_task: bool = False,
):
    self.collection = collections.ASRSpeechLabel(
        manifests_files=manifest_filepath,
        min_duration=min_duration,
        max_duration=max_duration,
        index_by_file_id=True,  # Must set this so the manifest lines can be indexed by file ID
    )
    self.file_occurence = count_occurence(self.collection.mapping)
    self.featurizer = featurizer
    self.trim = trim
    self.labels = labels if labels else self.collection.uniq_labels
    self.num_classes = len(self.labels)

    self.label2id, self.id2label = {}, {}
    for label_id, label in enumerate(self.labels):
        self.label2id[label] = label_id
        self.id2label[label_id] = label

    for idx in range(len(self.labels[:5])):
        logging.debug(" label id {} and its mapped label {}".format(idx, self.id2label[idx]))

    audio_tar_filepaths = expand_audio_filepaths(
        audio_tar_filepaths=audio_tar_filepaths,
        shard_strategy=shard_strategy,
        world_size=world_size,
        global_rank=global_rank,
    )

    # Put together WebDataset
    self._dataset = wd.WebDataset(urls=audio_tar_filepaths, nodesplitter=None)

    if shuffle_n > 0:
        self._dataset = self._dataset.shuffle(shuffle_n)
    else:
        logging.info("WebDataset will not shuffle files within the tar files.")

    self._dataset = (
        self._dataset.rename(audio=VALID_FILE_FORMATS, key='__key__')
        .to_tuple('audio', 'key')
        .pipe(self._filter)
        .map(f=self._build_sample)
    )
def from_config_dict(cls, config: 'DictConfig'):
    """Instantiates object using DictConfig-based configuration"""
    # Resolve the config dict
    if _HAS_HYDRA:
        if isinstance(config, DictConfig):
            config = OmegaConf.to_container(config, resolve=True)
            config = OmegaConf.create(config)
            OmegaConf.set_struct(config, True)
        config = maybe_update_config_version(config)

    # Hydra 0.x API
    if ('cls' in config or 'target' in config) and 'params' in config and _HAS_HYDRA:
        # regular hydra-based instantiation
        instance = hydra.utils.instantiate(config=config)
    # Hydra 1.x API
    elif '_target_' in config and _HAS_HYDRA:
        # regular hydra-based instantiation
        instance = hydra.utils.instantiate(config=config)
    else:
        instance = None
        imported_cls_tb = None
        # Attempt class path resolution from config `target` class (if it exists)
        if 'target' in config:
            target_cls = config["target"]  # No guarantee that this is an omegaconf class
            imported_cls = None
            try:
                # try to import the target class
                imported_cls = import_class_by_path(target_cls)
            except Exception:
                imported_cls_tb = traceback.format_exc()

            # try instantiating model with target class
            if imported_cls is not None:
                # if the calling class (cls) is a subclass of the imported class, use the subclass instead
                if issubclass(cls, imported_cls):
                    imported_cls = cls
                try:
                    instance = imported_cls(cfg=config)
                except Exception:
                    imported_cls_tb = traceback.format_exc()
                    instance = None

        # target class resolution was unsuccessful, fall back to current `cls`
        if instance is None:
            if imported_cls_tb is not None:
                logging.debug(
                    f"Model instantiation from target class {target_cls} failed with the following error.\n"
                    f"Falling back to `cls`.\n"
                    f"{imported_cls_tb}"
                )
            instance = cls(cfg=config)

    if not hasattr(instance, '_cfg'):
        instance._cfg = config
    return instance
def nemo_topk_check_previous_run(self):
    try:
        self.best_k_models
        self.kth_best_model_path
        self.best_model_score
        self.best_model_path
    except AttributeError:
        raise AttributeError("Lightning's ModelCheckpoint was updated. NeMoModelCheckpoint will need an update.")
    self.best_k_models = {}
    self.kth_best_model_path = ""
    self.best_model_score = None
    self.best_model_path = ""

    checkpoints = list(Path(self.dirpath).rglob("*.ckpt"))
    for checkpoint in checkpoints:
        if self.model_parallel_size is not None and self.model_parallel_size > 1:
            checkpoint = self._uninject_mp_rank(checkpoint)
        checkpoint = str(checkpoint)
        if checkpoint[-10:] == '-last.ckpt':
            continue
        # Find the monitor in the checkpoint name; only offset past it (plus the
        # '=' separator) once we know it is present, otherwise `find` returning
        # -1 could never be detected.
        index = checkpoint.find(self.monitor)
        if index != -1:
            index += len(self.monitor) + 1  # +1 for '='
            match = re.search('[A-Za-z]', checkpoint[index:])
            if match:
                value = checkpoint[index : index + match.start() - 1]  # -1 due to separator hyphen
                self.best_k_models[checkpoint] = float(value)
    if len(self.best_k_models) < 1:
        return  # No saved checkpoints yet

    _reverse = False if self.mode == "min" else True
    best_k_models = sorted(self.best_k_models, key=self.best_k_models.get, reverse=_reverse)

    # This section should be ok as rank zero will delete all excess checkpoints, since all other ranks are
    # instantiated after rank zero. models_to_delete should be 0 for all other ranks.
    if self.model_parallel_size is not None:
        models_to_delete = len(best_k_models) - self.model_parallel_size * self.save_top_k
    else:
        models_to_delete = len(best_k_models) - self.save_top_k
    logging.debug(f'Number of models to delete: {models_to_delete}')
    for _ in range(models_to_delete):
        model = best_k_models.pop(-1)
        self.best_k_models.pop(model)
        self._del_model_without_trainer(model)
        logging.debug(f"Removed checkpoint: {model}")

    self.kth_best_model_path = best_k_models[-1]
    self.best_model_path = best_k_models[0]
    self.best_model_score = self.best_k_models[self.best_model_path]
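# Worked example of the filename parsing above (the checkpoint name is
# illustrative):
#   monitor    = "val_loss"
#   checkpoint = "exp--val_loss=0.57-epoch=3.ckpt"
# `find` locates "val_loss", the offset skips "val_loss=", the first letter
# match is the 'e' of "epoch", and the slice (minus the separator hyphen)
# recovers "0.57", which is stored as that checkpoint's monitored score.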
def get_features(all_sents, tokenizer, max_seq_length, labels=None, verbose=True):
    """Encode a list of sentences into a list of tuples of (input_ids, segment_ids, input_mask, label)."""
    features = []
    sent_lengths = []
    too_long_count = 0
    for sent_id, sent in enumerate(all_sents):
        if sent_id % 1000 == 0:
            logging.debug(f"Encoding sentence {sent_id}/{len(all_sents)}")
        sent_subtokens = [tokenizer.cls_token]
        for word in sent:
            word_tokens = tokenizer.text_to_tokens(word)
            sent_subtokens.extend(word_tokens)
        if max_seq_length > 0 and len(sent_subtokens) + 1 > max_seq_length:
            sent_subtokens = sent_subtokens[:max_seq_length]
            too_long_count += 1
        sent_subtokens.append(tokenizer.sep_token)
        sent_lengths.append(len(sent_subtokens))

        input_ids = [tokenizer.tokens_to_ids(t) for t in sent_subtokens]

        # The mask has 1 for real tokens and 0 for padding tokens.
        # Only real tokens are attended to.
        input_mask = [1] * len(input_ids)
        segment_ids = [0] * len(input_ids)

        if verbose and sent_id < 2:
            logging.info("*** Example ***")
            logging.info(f"example {sent_id}: {sent}")
            logging.info("subtokens: %s" % " ".join(sent_subtokens))
            logging.info("input_ids: %s" % list2str(input_ids))
            logging.info("segment_ids: %s" % list2str(segment_ids))
            logging.info("input_mask: %s" % list2str(input_mask))
            logging.info("label: %s" % (labels[sent_id] if labels else "**Not Provided**"))

        label = labels[sent_id] if labels else -1
        features.append([np.asarray(input_ids), np.asarray(segment_ids), np.asarray(input_mask), label])

    if max_seq_length > -1 and too_long_count > 0:
        logging.warning(
            f'Found {too_long_count} out of {len(all_sents)} sentences with more than {max_seq_length} subtokens. '
            f'Truncated long sentences from the end.'
        )
    if verbose:
        get_stats(sent_lengths)
    return features
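# Usage sketch (hedged): `all_sents` is a list of pre-tokenized sentences
# (lists of words) and `tokenizer` is any NeMo-style tokenizer exposing
# text_to_tokens/tokens_to_ids. The sentences below are illustrative.
features = get_features(
    all_sents=[["hello", "world"], ["how", "are", "you"]],
    tokenizer=tokenizer,
    max_seq_length=128,
)
input_ids, segment_ids, input_mask, label = features[0]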
def write_predictions_to_file(
    predictions: List[dict],
    input_json_files: List[str],
    output_dir: str,
    schemas: object,
    state_tracker: str,
    eval_debug: bool,
    in_domain_services: set,
):
    """Save predicted dialogues as json files.

    Args:
        predictions: An iterator containing model predictions. This is the output of the predict method in the estimator.
        input_json_files: A list of json paths containing the dialogues to run inference on.
        output_dir: The directory where output json files will be created.
        schemas: Schemas to all services in the dst dataset
        state_tracker: state tracker option
        eval_debug: output evaluation debugging information
        in_domain_services: in domain services
    """
    logging.info(f"Writing predictions to {output_dir} started.")

    # Index all predictions.
    all_predictions = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))
    for idx, prediction in enumerate(predictions):
        eval_dataset, dialog_id, turn_id, service_name, model_task, slot_intent_id, value_id = prediction[
            'example_id'
        ].split('-')
        all_predictions[(dialog_id, turn_id, service_name)][int(model_task)][int(slot_intent_id)][
            int(value_id)
        ] = prediction
    logging.info(f'Predictions for {idx} examples in the {eval_dataset} dataset are being processed.')

    # Read each input file and write its predictions.
    for input_file_path in input_json_files:
        with open(input_file_path, encoding="UTF-8") as f:
            dialogs = json.load(f)
            logging.debug(f'{input_file_path} file is loaded')
            pred_dialogs = []
            for d in dialogs:
                pred_dialog = get_predicted_dialog(d, all_predictions, schemas, state_tracker)
                pred_dialogs.append(pred_dialog)
        input_file_name = os.path.basename(input_file_path)
        output_file_path = os.path.join(output_dir, input_file_name)
        with open(output_file_path, "w", encoding="UTF-8") as f:
            json.dump(pred_dialogs, f, indent=2, separators=(",", ": "), sort_keys=True)
def from_config_dict(cls, config: DictConfig):
    """Instantiates object using DictConfig-based configuration"""
    # Resolve the config dict
    if isinstance(config, DictConfig):
        config = OmegaConf.to_container(config, resolve=True)
        config = OmegaConf.create(config)
        OmegaConf.set_struct(config, True)
    config = maybe_update_config_version(config)

    # Hydra 0.x API
    if ('cls' in config or 'target' in config) and 'params' in config:
        # regular hydra-based instantiation
        instance = hydra.utils.instantiate(config=config)
    # Hydra 1.x API
    elif '_target_' in config:
        # regular hydra-based instantiation
        instance = hydra.utils.instantiate(config=config)
    else:
        instance = None

        # Attempt class path resolution from config `target` class (if it exists)
        if 'target' in config:
            target_cls = config.target
            imported_cls = None
            try:
                # try to import the target class
                imported_cls = import_class_by_path(target_cls)
            except (ImportError, ModuleNotFoundError):
                logging.debug(f'Target class `{target_cls}` could not be imported, falling back to original cls')

            # try instantiating model with target class
            if imported_cls is not None:
                try:
                    instance = imported_cls(cfg=config)
                except Exception:
                    imported_cls_tb = traceback.format_exc()
                    logging.debug(
                        f"Model instantiation from target class failed with the following error.\n"
                        f"Falling back to `cls`.\n"
                        f"{imported_cls_tb}"
                    )
                    instance = None

        # target class resolution was unsuccessful, fall back to current `cls`
        if instance is None:
            instance = cls(cfg=config)

    if not hasattr(instance, '_cfg'):
        instance._cfg = config
    return instance
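# Hedged usage sketch: a Hydra 1.x style config carrying a `_target_` key is
# instantiated directly via hydra.utils.instantiate. The class path and
# parameter below are hypothetical placeholders.
cfg = OmegaConf.create({"_target_": "my_project.models.MyModel", "hidden_size": 256})
model = MyModel.from_config_dict(cfg)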
def save_to(self, filename: str, module_names: Optional[List[str]] = None):
    """
    Saves the state of trainable modules in the graph to a checkpoint file.

    Args:
        filename (string): Name of the file where the checkpoint will be saved.
        module_names: List of modules to be saved (Optional). If not passed, all modules will be saved.
    Raises:
        KeyError: If a module name is not recognized.
    """
    # Work on all modules.
    if module_names is None:
        module_names = self._modules.keys()

    # Prepare the "graph checkpoint".
    chkpt = {"header": {"nemo_core_version": nemo_version, "name": self.name}, "modules": {}}

    log_str = ''
    # Iterate through the modules one by one.
    for name in module_names:
        if name not in self._modules.keys():
            raise KeyError("Module `{}` not present in the `{}` graph".format(name, self.name))
        # Check module type.
        module = self._modules[name]
        if module.type == ModuleType.trainable:
            # Get module state_dict().
            chkpt["modules"][name] = get_state_dict(module)
            log_str += " * Module '{}' ({}) params saved \n".format(module.name, type(module).__name__)
        else:
            logging.debug("Module `{}` is not trainable so cannot be saved".format(name))

    # Save checkpoint.
    save(chkpt, filename)
    log_str = "Saved the '{}' graph to a checkpoint `{}`:\n".format(self.name, filename) + log_str
    logging.info(log_str)
def generate_vad_frame_pred(vad_model, window_length_in_sec, shift_length_in_sec, manifest_vad_input, out_dir):
    """
    Generate VAD frame-level predictions and write them to out_dir.
    """
    time_unit = int(window_length_in_sec / shift_length_in_sec)
    trunc = int(time_unit / 2)
    trunc_l = time_unit - trunc
    all_len = 0

    data = []
    for line in open(manifest_vad_input, 'r', encoding='utf-8'):
        file = json.loads(line)['audio_filepath'].split("/")[-1]
        data.append(file.split(".wav")[0])
    logging.info(f"Inference on {len(data)} audio files/json lines!")

    status = get_vad_stream_status(data)
    for i, test_batch in enumerate(vad_model.test_dataloader()):
        test_batch = [x.to(vad_model.device) for x in test_batch]
        with autocast():
            log_probs = vad_model(input_signal=test_batch[0], input_signal_length=test_batch[1])
            probs = torch.softmax(log_probs, dim=-1)
            pred = probs[:, 1]

            # Trim the overlap between consecutive windows of the same file.
            if status[i] == 'start':
                to_save = pred[:-trunc]
            elif status[i] == 'next':
                to_save = pred[trunc:-trunc_l]
            elif status[i] == 'end':
                to_save = pred[trunc_l:]
            else:
                to_save = pred

            all_len += len(to_save)
            outpath = os.path.join(out_dir, data[i] + ".frame")
            with open(outpath, "a", encoding='utf-8') as fout:
                for f in range(len(to_save)):
                    fout.write('{0:0.4f}\n'.format(to_save[f]))
        del test_batch
        if status[i] == 'end' or status[i] == 'single':
            logging.debug(f"Overall length of prediction of {data[i]} is {all_len}!")
            all_len = 0
    return out_dir
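# Worked example of the overlap trimming above (values chosen so the float
# division is exact): window_length_in_sec=0.5 and shift_length_in_sec=0.25
# give time_unit=2, trunc=1, trunc_l=1, so a 'start' chunk drops its last
# frame, a 'next' chunk drops one frame on each side, an 'end' chunk drops
# its first frame, and 'single' chunks are written unmodified.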
def setup_loss(self, class_balancing: str = None):
    """Setup loss

    Setup or update loss.

    Args:
        class_balancing: 'weighted_loss' to use class weights during training, or None
    """
    if class_balancing not in ['weighted_loss', None]:
        raise ValueError(f'Class balancing {class_balancing} is not supported. Choose from: [null, weighted_loss]')
    if class_balancing == 'weighted_loss' and self.class_weights:
        # you may need to increase the number of epochs for convergence when using weighted_loss
        loss = CrossEntropyLoss(logits_ndim=3, weight=self.class_weights)
        logging.debug(f'Using {class_balancing} class balancing.')
    else:
        loss = CrossEntropyLoss(logits_ndim=3)
        logging.debug('Using CrossEntropyLoss without class balancing.')
    return loss
def _perform_speech_activity_detection(self):
    """
    Checks the type of speech activity detection from the config. Choices are NeMo VAD,
    an external vad manifest, and oracle VAD (generates speech activity labels from provided RTTM files).
    """
    if self.has_vad_model:
        self._dont_auto_split = False
        self._split_duration = 50
        manifest_vad_input = self._diarizer_params.manifest_filepath

        if not self._dont_auto_split:
            logging.info("Splitting long audio files to avoid CUDA memory issues")
            logging.debug("Try a smaller split_duration if you still have CUDA memory issues")
            config = {
                'manifest_filepath': manifest_vad_input,
                'time_length': self._vad_window_length_in_sec,
                'split_duration': self._split_duration,
                'num_workers': self._cfg.num_workers,
            }
            manifest_vad_input = prepare_manifest(config)
        else:
            logging.warning(
                "If you encounter CUDA memory issues, try splitting manifest entries by split_duration to avoid them."
            )

        self._setup_vad_test_data(manifest_vad_input)
        self._run_vad(manifest_vad_input)

    elif self._diarizer_params.vad.external_vad_manifest is not None:
        self._speaker_manifest_path = self._diarizer_params.vad.external_vad_manifest
    elif self._diarizer_params.oracle_vad:
        self._speaker_manifest_path = os.path.join(self._speaker_dir, 'oracle_vad_manifest.json')
        self._speaker_manifest_path = write_rttm2manifest(self.AUDIO_RTTM_MAP, self._speaker_manifest_path)
    else:
        raise ValueError(
            "Only one of diarizer.oracle_vad, vad.model_path or vad.external_vad_manifest must be passed"
        )
def build_vocab_from_csv(self, data_csv_file, col="smiles"):
    """
    Learns vocabulary from a CSV file. Can be called multiple times to update the vocabulary.
    """
    logging.debug(f"Building vocabulary from CSV col = {col} file = {data_csv_file}")

    # NOTE this has to be run on each CSV file
    if not os.path.exists(data_csv_file):
        raise ValueError(f"Data file: {data_csv_file} is missing")

    df = pd.read_csv(data_csv_file)

    vocab = self.vocab
    for d in df[col]:
        tokens = self.text_to_tokens(d)
        logging.debug(f"Text: {d}, Tokens: {tokens}")
        for token in tokens:
            if token not in vocab:
                vocab[token] = len(vocab)

    sorted_vocab = sorted(vocab.items(), key=lambda k_v: k_v[1])
    logging.debug(f"Vocab: {sorted_vocab}")

    self.vocab = vocab
    self._update_cache()
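# Usage sketch (hedged): grow the same vocabulary across several CSVs by
# calling the method repeatedly. The instance name and file names below are
# hypothetical placeholders.
tokenizer.build_vocab_from_csv("train_part1.csv", col="smiles")
tokenizer.build_vocab_from_csv("train_part2.csv", col="smiles")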
def forward(self, gating_preds, point_outputs_pred, belief_state, user_uttr):
    """
    Processes the TRADE model output and updates the dialogue (belief) state with the model's predictions.

    Args:
        user_uttr (str): user utterance
        belief_state (dict): dialogue belief state, contains slot-slot value pairs for all domains
        gating_preds (float): TRADE model gating predictions
        point_outputs_pred (float): TRADE model pointer predictions
    Returns:
        updated belief_state (dict)
        updated request_state (dict): contains requested slot-slot value pairs for each domain
    """
    gate_outputs_max, point_outputs_max = self.get_trade_prediction(gating_preds, point_outputs_pred)
    trade_output = self.get_human_readable_output(gate_outputs_max, point_outputs_max)[0]
    logging.debug('TRADE output: %s', trade_output)

    new_belief_state = self.reformat_belief_state(
        trade_output, copy.deepcopy(belief_state), self.data_desc.ontology_value_dict
    )

    # update request state based on the latest user utterance
    # extract current user output
    new_request_state = self.detect_requestable_slots(user_uttr.lower(), self.data_desc.det_dict)
    logging.debug('Belief State after TRADE: %s', new_belief_state)
    logging.debug('Request State after TRADE: %s', new_request_state)
    return new_belief_state, new_request_state
def load_tokenizer(self, base_fname):
    """
    Loads the tokenizer's regex (base_fname.model) and vocab (base_fname.vocab) files.
    """
    if base_fname.endswith(".model"):
        base_fname = os.path.splitext(base_fname)[0]
    if base_fname:
        self.base_fname = base_fname
    if not self.base_fname:
        raise ValueError("base_fname must be specified")

    vocab_file = self.base_fname + '.vocab'
    regex_file = self.base_fname + '.model'

    # load vocab file
    # vocab_file: path to file with vocabulary which consists
    # of characters separated by \n (None/"" for empty vocab)
    logging.debug(f"Loading vocabulary from file = {vocab_file}")
    if os.path.exists(vocab_file):
        vocab = {}
        with open(vocab_file, "r") as f:
            for line in f:
                line = line.strip()
                if line:
                    vocab[line] = len(vocab)
        self.vocab = vocab
    else:
        raise RuntimeError(f"Missing vocab_file = {vocab_file}")

    # load regex from a file
    if os.path.exists(regex_file):
        logging.debug(f"Loading regex from file = {regex_file}")
        self.regex = open(regex_file, encoding="utf-8").read().strip()
    else:
        raise RuntimeError(f"Missing regex_file = {regex_file}")

    return self
def configure_checkpointing(trainer: 'pytorch_lightning.Trainer', log_dir: Path, name: str, params: 'DictConfig'):
    """Adds a ModelCheckpoint callback to the trainer. Raises CheckpointMisconfigurationError if the trainer
    already has a ModelCheckpoint callback or if trainer.weights_save_path was passed to the Trainer.
    """
    for callback in trainer.callbacks:
        if isinstance(callback, ModelCheckpoint):
            raise CheckpointMisconfigurationError(
                "The pytorch lightning trainer that was passed to exp_manager contained a ModelCheckpoint "
                "and create_checkpoint_callback was set to True. Please either set create_checkpoint_callback "
                "to False, or remove ModelCheckpoint from the lightning trainer"
            )
    if Path(trainer.weights_save_path) != Path.cwd():
        raise CheckpointMisconfigurationError(
            "The pytorch lightning trainer was passed weights_save_path. This variable is ignored by exp_manager"
        )

    # Create the callback and attach it to the trainer
    if "filepath" in params:
        if params.filepath is not None:
            logging.warning("filepath is deprecated. Please switch to dirpath and filename instead")
            if params.dirpath is None:
                params.dirpath = Path(params.filepath).parent
            if params.filename is None:
                params.filename = Path(params.filepath).name
        with open_dict(params):
            del params["filepath"]
    if params.dirpath is None:
        params.dirpath = Path(log_dir / 'checkpoints')
    if params.filename is None:
        params.filename = f'{name}--{{{params.monitor}:.2f}}-{{epoch}}'
    if params.prefix is None:
        params.prefix = name
    NeMoModelCheckpoint.CHECKPOINT_NAME_LAST = params.filename + '-last'

    logging.debug(params.dirpath)
    logging.debug(params.filename)
    logging.debug(params.prefix)

    if "val" in params.monitor:
        if (
            trainer.max_epochs is not None
            and trainer.max_epochs != -1
            and trainer.max_epochs < trainer.check_val_every_n_epoch
        ):
            logging.error(
                "The checkpoint callback was told to monitor a validation value but trainer.max_epochs("
                f"{trainer.max_epochs}) was less than trainer.check_val_every_n_epoch({trainer.check_val_every_n_epoch}"
                f"). It is very likely this run will fail with ModelCheckpoint(monitor='{params.monitor}') not found "
                "in the returned metrics. Please ensure that validation is run within trainer.max_epochs."
            )
        elif trainer.max_steps is not None:
            logging.warning(
                "The checkpoint callback was told to monitor a validation value and trainer's max_steps was set to "
                f"{trainer.max_steps}. Please ensure that max_steps will run for at least "
                f"{trainer.check_val_every_n_epoch} epochs to ensure that checkpointing will not error out."
            )

    checkpoint_callback = NeMoModelCheckpoint(**params)
    checkpoint_callback.last_model_path = trainer.resume_from_checkpoint or ""
    trainer.callbacks.append(checkpoint_callback)
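# Hedged sketch of a `params` config this function consumes; the keys mirror
# the ones read above and the values are illustrative, not NeMo defaults.
# `trainer` is assumed to be an already-constructed pytorch_lightning.Trainer.
params = OmegaConf.create(
    {"monitor": "val_loss", "save_top_k": 3, "dirpath": None, "filename": None, "prefix": None}
)
configure_checkpointing(trainer, log_dir=Path("./exp"), name="my_run", params=params)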
def add_noncategorical_slots(self, state_update: dict, system_span_boundaries: dict, user_span_boundaries: dict):
    """Add features for non-categorical slots.

    Args:
        state_update: slot value pairs of the state update
        system_span_boundaries: span boundaries of slot values in the system utterance
        user_span_boundaries: span boundaries of slot values in the user utterance
    """
    noncategorical_slots = self.service_schema.non_categorical_slots
    slot = noncategorical_slots[self.noncategorical_slot_id]

    values = state_update.get(slot, [])
    if not values:
        self.noncategorical_slot_status = STATUS_OFF
    elif values[0] == STR_DONTCARE:
        self.noncategorical_slot_status = STATUS_DONTCARE
    else:
        self.noncategorical_slot_status = STATUS_ACTIVE
        # Add indices of the start and end tokens for the first encountered
        # value. Spans in the user utterance are prioritized over the system
        # utterance. If a span is not found, the slot value is ignored.
        if slot in user_span_boundaries:
            start, end = user_span_boundaries[slot]
        elif slot in system_span_boundaries:
            start, end = system_span_boundaries[slot]
        else:
            # A span may not be found because the value was cropped out or because
            # the value was mentioned earlier in the dialogue. Since this model
            # only makes use of the last two utterances to predict state updates,
            # it will fail in such cases.
            logging.debug(
                f'Slot values {str(values)} not found in user or system utterance in example with id - {self.example_id}.'
            )
            start = 0
            end = 0
        self.noncategorical_slot_value_start = start
        self.noncategorical_slot_value_end = end
def save_tokenizer(self, base_fname=None):
    """
    Saves the tokenizer's regex (base_fname.model) and vocab (base_fname.vocab) files.
    """
    if base_fname is not None and base_fname.endswith(".model"):
        base_fname = os.path.splitext(base_fname)[0]
    if base_fname:
        self.base_fname = base_fname
    if not self.base_fname:
        raise ValueError("base_fname must be specified")

    vocab_file = self.base_fname + '.vocab'
    regex_file = self.base_fname + '.model'

    logging.debug(f"Saving vocabulary to file = {vocab_file}")
    with open(vocab_file, 'w') as fp:
        # Write one token per line, ordered by token id, so load_tokenizer
        # can rebuild the same token -> id mapping.
        for token, _ in sorted(self.vocab.items(), key=lambda k_v: k_v[1]):
            fp.write(f"{token}\n")

    logging.debug(f"Saving regex to file = {regex_file}")
    with open(regex_file, 'w') as fp:
        fp.write(self.regex)
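# Hedged round-trip sketch: persist the tokenizer and reload it. The instance
# name and base filename below are hypothetical placeholders; save_tokenizer
# writes regex_tok.vocab and regex_tok.model, which load_tokenizer reads back.
tokenizer.save_tokenizer("regex_tok")
tokenizer.load_tokenizer("regex_tok")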
def __init__(
    self,
    *,
    manifest_filepath: str,
    labels: List[str],
    feature_loader,
    is_speaker_emb: bool = False,
):
    super().__init__()
    self.collection = collections.ASRFeatureSequenceLabel(
        manifests_files=manifest_filepath.split(','),
    )

    self.feature_loader = feature_loader
    self.labels = labels if labels else self.collection.uniq_labels
    self.is_speaker_emb = is_speaker_emb

    self.label2id, self.id2label = {}, {}
    for label_id, label in enumerate(self.labels):
        self.label2id[label] = label_id
        self.id2label[label_id] = label

    for idx in range(len(self.labels[:5])):
        logging.debug(" label id {} and its mapped label {}".format(idx, self.id2label[idx]))
def main(cfg: DictConfig) -> None:
    logging.debug(f'Config Params: {OmegaConf.to_yaml(cfg)}')

    if cfg.pretrained_model is None:
        raise ValueError("A pre-trained model should be provided.")
    _, model = instantiate_model_and_trainer(cfg, ITN_MODEL, False)

    text_file = cfg.inference.from_file
    logging.info(f"Running inference on {text_file}...")
    if not os.path.exists(text_file):
        raise ValueError(f"{text_file} not found.")

    with open(text_file, "r", encoding="utf-8") as f:
        lines = f.readlines()

    batch_size = cfg.inference.get("batch_size", 8)

    batch, all_preds = [], []
    for i, line in enumerate(lines):
        s = spoken_preprocessing(line)  # this is the same input transformation as in corpus preparation
        batch.append(s.strip())
        if len(batch) == batch_size or i == len(lines) - 1:
            outputs = model._infer(batch)
            for x in outputs:
                all_preds.append(x)
            batch = []
    if len(all_preds) != len(lines):
        raise ValueError(
            f"number of input lines and predictions is different: predictions={len(all_preds)}; lines={len(lines)}"
        )
    out_file = cfg.inference.out_file
    with open(f"{out_file}", "w", encoding="utf-8") as f_out:
        f_out.write("\n".join(all_preds))
    logging.info(f"Predictions saved to {out_file}.")
def build_vocab_from_text(self, data_text_file):
    """
    Learns vocabulary from a text file. Can be called multiple times to update the vocabulary.
    """
    logging.debug(f"Building vocabulary from TEXT file = {data_text_file}")

    # NOTE this has to be run on each text file
    if not os.path.exists(data_text_file):
        raise ValueError(f"Data file: {data_text_file} is missing")

    vocab = self.vocab
    for d in open(data_text_file, encoding="utf-8").readlines():
        d = d.rstrip()
        tokens = self.text_to_tokens(d)
        logging.debug(f"Text: {d}, Tokens: {tokens}")
        for token in tokens:
            if token not in vocab:
                vocab[token] = len(vocab)

    sorted_vocab = sorted(vocab.items(), key=lambda k_v: k_v[1])
    logging.debug(f"Vocab: {sorted_vocab}")

    self.vocab = vocab
    self._update_cache()
def get_features(
    queries,
    max_seq_length,
    tokenizer,
    pad_label=128,
    word_level_slots=None,
    ignore_extra_tokens=False,
    ignore_start_end=False,
):
    """
    Convert queries (utterance, intent label and slot labels) to BERT input format.
    """
    all_subtokens = []
    all_loss_mask = []
    all_subtokens_mask = []
    all_segment_ids = []
    all_input_ids = []
    all_input_mask = []
    sent_lengths = []
    all_slots = []

    with_label = word_level_slots is not None

    for i, query in enumerate(queries):
        words = query.strip().split()
        subtokens = [tokenizer.cls_token]
        loss_mask = [1 - ignore_start_end]
        subtokens_mask = [0]
        if with_label:
            slots = [pad_label]

        for j, word in enumerate(words):
            word_tokens = tokenizer.text_to_tokens(word)
            # to handle emojis that could be neglected during tokenization
            if len(word.strip()) > 0 and len(word_tokens) == 0:
                word_tokens = [tokenizer.ids_to_tokens(tokenizer.unk_id)]
            subtokens.extend(word_tokens)
            # mask all sub-word tokens except the first token in a word;
            # use the label of the first sub-word token as the label for the
            # entire word to eliminate the need for disambiguation
            loss_mask.append(1)
            loss_mask.extend([int(not ignore_extra_tokens)] * (len(word_tokens) - 1))
            subtokens_mask.append(1)
            subtokens_mask.extend([0] * (len(word_tokens) - 1))
            if with_label:
                slots.extend([word_level_slots[i][j]] * len(word_tokens))

        subtokens.append(tokenizer.sep_token)
        loss_mask.append(1 - ignore_start_end)
        subtokens_mask.append(0)
        sent_lengths.append(len(subtokens))
        all_subtokens.append(subtokens)
        all_loss_mask.append(loss_mask)
        all_subtokens_mask.append(subtokens_mask)
        all_input_mask.append([1] * len(subtokens))
        if with_label:
            slots.append(pad_label)
            all_slots.append(slots)

    max_seq_length_data = max(sent_lengths)
    max_seq_length = min(max_seq_length, max_seq_length_data) if max_seq_length > 0 else max_seq_length_data
    logging.info(f'Setting max length to: {max_seq_length}')
    get_stats(sent_lengths)

    # truncate and pad samples
    (
        all_slots,
        all_subtokens,
        all_input_mask,
        all_loss_mask,
        all_subtokens_mask,
        all_input_ids,
        all_segment_ids,
    ) = DialogueBERTDataset.truncate_and_pad(
        max_seq_length,
        ignore_start_end,
        with_label,
        pad_label,
        tokenizer,
        all_slots,
        all_subtokens,
        all_input_mask,
        all_loss_mask,
        all_subtokens_mask,
        all_input_ids,
        all_segment_ids,
    )

    # log examples for debugging
    logging.debug("*** Some Examples of Processed Data ***")
    for i in range(min(len(all_input_ids), 5)):
        logging.debug("i: %s" % (i))
        logging.debug("subtokens: %s" % " ".join(list(map(str, all_subtokens[i]))))
        logging.debug("loss_mask: %s" % " ".join(list(map(str, all_loss_mask[i]))))
        logging.debug("input_mask: %s" % " ".join(list(map(str, all_input_mask[i]))))
        logging.debug("subtokens_mask: %s" % " ".join(list(map(str, all_subtokens_mask[i]))))
        if with_label:
            logging.debug("slots_label: %s" % " ".join(list(map(str, all_slots[i]))))

    return (all_input_ids, all_segment_ids, all_input_mask, all_loss_mask, all_subtokens_mask, all_slots)
def __init__(
    self,
    *,
    audio_tar_filepaths: Union[str, List[str]],
    manifest_filepath: str,
    labels: List[str],
    featurizer,
    shuffle_n: int = 0,
    min_duration: Optional[float] = 0.1,
    max_duration: Optional[float] = None,
    trim: bool = False,
    load_audio: bool = True,
    shard_strategy: str = "scatter",
    global_rank: int = 0,
    world_size: int = 0,
):
    self.collection = collections.ASRSpeechLabel(
        manifests_files=manifest_filepath.split(','),
        min_duration=min_duration,
        max_duration=max_duration,
        index_by_file_id=True,  # Must set this so the manifest lines can be indexed by file ID
    )
    self.file_occurence = count_occurence(self.collection.mapping)
    self.featurizer = featurizer
    self.trim = trim
    self.load_audio = load_audio
    self.labels = labels if labels else self.collection.uniq_labels
    self.num_classes = len(self.labels)

    self.label2id, self.id2label = {}, {}
    for label_id, label in enumerate(self.labels):
        self.label2id[label] = label_id
        self.id2label[label_id] = label

    for idx in range(len(self.labels[:5])):
        logging.debug(" label id {} and its mapped label {}".format(idx, self.id2label[idx]))

    valid_shard_strategies = ['scatter', 'replicate']
    if shard_strategy not in valid_shard_strategies:
        raise ValueError(f"`shard_strategy` must be one of {valid_shard_strategies}")

    if isinstance(audio_tar_filepaths, str):
        # Replace '(', '[', '<' and '_OP_' with '{'
        brace_keys_open = ['(', '[', '<', '_OP_']
        for bkey in brace_keys_open:
            if bkey in audio_tar_filepaths:
                audio_tar_filepaths = audio_tar_filepaths.replace(bkey, "{")

        # Replace ')', ']', '>' and '_CL_' with '}'
        brace_keys_close = [')', ']', '>', '_CL_']
        for bkey in brace_keys_close:
            if bkey in audio_tar_filepaths:
                audio_tar_filepaths = audio_tar_filepaths.replace(bkey, "}")

    # Check for distributed and partition shards accordingly
    if world_size > 1:
        if isinstance(audio_tar_filepaths, str):
            # Brace expand
            audio_tar_filepaths = list(braceexpand.braceexpand(audio_tar_filepaths))

        if shard_strategy == 'scatter':
            logging.info("All tarred dataset shards will be scattered evenly across all nodes.")
            if len(audio_tar_filepaths) % world_size != 0:
                logging.warning(
                    f"Number of shards in tarred dataset ({len(audio_tar_filepaths)}) is not divisible "
                    f"by number of distributed workers ({world_size})."
                )
            begin_idx = (len(audio_tar_filepaths) // world_size) * global_rank
            end_idx = begin_idx + (len(audio_tar_filepaths) // world_size)
            audio_tar_filepaths = audio_tar_filepaths[begin_idx:end_idx]
            logging.info(
                "Partitioning tarred dataset: process (%d) taking shards [%d, %d)", global_rank, begin_idx, end_idx
            )
        elif shard_strategy == 'replicate':
            logging.info("All tarred dataset shards will be replicated across all nodes.")
        else:
            raise ValueError(f"Invalid shard strategy ! Allowed values are : {valid_shard_strategies}")

    # Put together WebDataset
    self._dataset = wd.WebDataset(audio_tar_filepaths)

    if shuffle_n > 0:
        self._dataset = self._dataset.shuffle(shuffle_n)
    else:
        logging.info("WebDataset will not shuffle files within the tar files.")

    self._dataset = (
        self._dataset.rename(audio='wav', key='__key__')
        .to_tuple('audio', 'key')
        .pipe(self._filter)
        .map(f=self._build_sample)
    )