def _default_save_to(self, save_path: str):
    """
    Saves model instance (weights and configuration) into a .nemo file.
    You can use the "restore_from" method to fully restore the instance from the .nemo file.

    A .nemo file is an archive (tar.gz) with the following:
        model_config.yaml - model configuration in .yaml format. You can deserialize this into the cfg
            argument for the model's constructor
        model_weights.ckpt - model checkpoint

    Args:
        save_path: Path to .nemo file where model instance should be saved
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        config_yaml = path.join(tmpdir, _MODEL_CONFIG_YAML)
        model_weights = path.join(tmpdir, _MODEL_WEIGHTS)
        if hasattr(self, 'artifacts') and self.artifacts is not None:
            for (conf_path, src) in self.artifacts:
                try:
                    if os.path.exists(src):
                        shutil.copy2(src, tmpdir)
                except Exception:
                    logging.error(f"Could not copy artifact {src} used in {conf_path}")

        self.to_config_file(path2yaml_file=config_yaml)
        torch.save(self.state_dict(), model_weights)
        self.__make_nemo_file_from_folder(filename=save_path, source_dir=tmpdir)
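# Hedged usage sketch (assumption, not taken from the source): how a model built on this
# mixin would typically be archived and restored, relying only on the save_to / restore_from
# pair described in the docstring above. `MyModel`, `model_cfg`, and the path are hypothetical.
model = MyModel(cfg=model_cfg)                      # hypothetical ModelPT-style subclass
model.save_to("my_model.nemo")                      # writes config + weights into a tar archive
restored = MyModel.restore_from("my_model.nemo")    # rebuilds the instance from the archive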
def __setup_dataloader_from_config(self, cfg, shuffle_should_be: bool = True, name: str = "train"):
    if "dataset" not in cfg or not isinstance(cfg.dataset, DictConfig):
        raise ValueError(f"No dataset for {name}")  # TODO
    if "dataloader_params" not in cfg or not isinstance(cfg.dataloader_params, DictConfig):
        raise ValueError(f"No dataloader_params for {name}")  # TODO
    if shuffle_should_be:
        if 'shuffle' not in cfg.dataloader_params:
            logging.warning(
                f"Shuffle should be set to True for {self}'s {name} dataloader but was not found in its "
                "config. Manually setting to True"
            )
            with open_dict(cfg["dataloader_params"]):
                cfg.dataloader_params.shuffle = True
        elif not cfg.dataloader_params.shuffle:
            logging.error(f"The {name} dataloader for {self} has shuffle set to False!!!")
    elif not shuffle_should_be and cfg.dataloader_params.shuffle:
        logging.error(f"The {name} dataloader for {self} has shuffle set to True!!!")

    dataset = instantiate(cfg.dataset)

    return torch.utils.data.DataLoader(dataset, collate_fn=dataset.collate_fn, **cfg.dataloader_params)
def segment_from_file(cls, audio_file, target_sr=None, n_segments=0, trim=False, orig_sr=None):
    """Grabs n_segments number of samples from audio_file randomly from the
    file as opposed to at a specified offset.

    Note that audio_file can be either the file path, or a file-like object.
    """
    try:
        with sf.SoundFile(audio_file, 'r') as f:
            sample_rate = f.samplerate
            if n_segments > 0 and len(f) > n_segments:
                max_audio_start = len(f) - n_segments
                audio_start = random.randint(0, max_audio_start)
                f.seek(audio_start)
                samples = f.read(n_segments, dtype='float32')
            else:
                samples = f.read(dtype='float32')
        # transpose once so that multi-channel audio is (channels, samples)
        samples = samples.transpose()
    except RuntimeError as e:
        logging.error(f"Loading {audio_file} via SoundFile raised RuntimeError: `{e}`.")

    return cls(samples, sample_rate, target_sr=target_sr, trim=trim, orig_sr=orig_sr)
def configure_checkpointing(trainer: 'pytorch_lightning.Trainer', log_dir: Path, name: str, params: 'DictConfig'):
    """ Adds ModelCheckpoint to trainer. Raises CheckpointMisconfigurationError if trainer already has a
    ModelCheckpoint callback or if trainer.weights_save_path was passed to Trainer.
    """
    for callback in trainer.callbacks:
        if isinstance(callback, ModelCheckpoint):
            raise CheckpointMisconfigurationError(
                "The pytorch lightning trainer that was passed to exp_manager contained a ModelCheckpoint "
                "and create_checkpoint_callback was set to True. Please either set create_checkpoint_callback "
                "to False, or remove ModelCheckpoint from the lightning trainer"
            )
    if Path(trainer.weights_save_path) != Path.cwd():
        raise CheckpointMisconfigurationError(
            "The pytorch lightning trainer was passed weights_save_path. This variable is ignored by exp_manager"
        )

    # Create the callback and attach it to trainer
    if "filepath" in params:
        if params.filepath is not None:
            logging.warning("filepath is deprecated. Please switch to dirpath and filename instead")
            if params.dirpath is None:
                params.dirpath = Path(params.filepath).parent
            if params.filename is None:
                params.filename = Path(params.filepath).name
        with open_dict(params):
            del params["filepath"]
    if params.dirpath is None:
        params.dirpath = Path(log_dir / 'checkpoints')
    if params.filename is None:
        params.filename = f'{name}--{{{params.monitor}:.2f}}-{{epoch}}'
    if params.prefix is None:
        params.prefix = name
    NeMoModelCheckpoint.CHECKPOINT_NAME_LAST = params.filename + '-last'

    logging.debug(params.dirpath)
    logging.debug(params.filename)
    logging.debug(params.prefix)

    if "val" in params.monitor:
        if (
            trainer.max_epochs is not None
            and trainer.max_epochs != -1
            and trainer.max_epochs < trainer.check_val_every_n_epoch
        ):
            logging.error(
                "The checkpoint callback was told to monitor a validation value but trainer.max_epochs("
                f"{trainer.max_epochs}) was less than trainer.check_val_every_n_epoch({trainer.check_val_every_n_epoch}"
                f"). It is very likely this run will fail with ModelCheckpoint(monitor='{params.monitor}') not found "
                "in the returned metrics. Please ensure that validation is run within trainer.max_epochs."
            )
        elif trainer.max_steps is not None:
            logging.warning(
                "The checkpoint callback was told to monitor a validation value and trainer's max_steps was set to "
                f"{trainer.max_steps}. Please ensure that max_steps will run for at least "
                f"{trainer.check_val_every_n_epoch} epochs to ensure that checkpointing will not error out."
            )

    checkpoint_callback = NeMoModelCheckpoint(**params)
    checkpoint_callback.last_model_path = trainer.resume_from_checkpoint or ""
    trainer.callbacks.append(checkpoint_callback)
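# Hedged sketch (assumption, not from the source) of the deprecated-`filepath` handling above:
# a single filepath is split into the newer dirpath / filename pair before the callback is built.
# The keys mirror the function's params; the concrete values are hypothetical.
from pathlib import Path
from omegaconf import OmegaConf, open_dict

params = OmegaConf.create(
    {"filepath": "/results/exp1/mdl--{val_loss:.2f}-{epoch}", "dirpath": None, "filename": None,
     "monitor": "val_loss", "prefix": None}
)
if params.filepath is not None:
    params.dirpath = str(Path(params.filepath).parent)  # -> "/results/exp1"
    params.filename = Path(params.filepath).name        # -> "mdl--{val_loss:.2f}-{epoch}"
with open_dict(params):
    del params["filepath"]                              # the legacy key is dropped afterwards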
def sync_all_processes(self, status=True):
    """ Helper function for testing that allows process 0 to inform all other processes of failures.
    Does nothing if not using distributed training. Usage example can be seen in examples/asr/jasper_an4.py

    Args:
        status (bool): Defaults to True. If any process passes False, it will trigger a graceful exit on all
            other processes. It is assumed that the process that passed False will print an error message on
            its own and exit
    """
    if self._world_size == 1:
        logging.info("sync_all_processes does nothing if there is one process")
        return
    if True:  # self._backend == Backend.PyTorch:
        import torch

        status_tensor = torch.cuda.IntTensor([status])
        torch.distributed.all_reduce(status_tensor, op=torch.distributed.ReduceOp.MIN)
        if status_tensor.item() == 0:
            logging.error("At least one process had a failure")
            if status:
                raise ValueError(
                    f"Process with global rank {self._global_rank} entered"
                    " sync_all_processes with a passing status, but "
                    "another process indicated a failure"
                )
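# Hedged usage sketch (assumption, not taken from the source): each rank reports its local
# success flag so that one failing rank aborts all of them gracefully. `nf` stands for a
# hypothetical factory/trainer object exposing sync_all_processes, and run_validation_epoch
# is a hypothetical per-rank workload (the docstring above points to examples/asr/jasper_an4.py
# for a real example).
try:
    run_validation_epoch()               # hypothetical per-rank work that may raise
    nf.sync_all_processes(status=True)   # this rank is healthy
except Exception:
    nf.sync_all_processes(status=False)  # tell the other ranks to exit gracefully
    raise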
def convert_model_config_to_dict_config(cfg: Union['DictConfig', 'NemoConfig']) -> 'DictConfig':
    """
    Converts its input into a standard DictConfig.
    Possible input values are:
    -   DictConfig
    -   A dataclass which is a subclass of NemoConfig

    Args:
        cfg: A dict-like object.

    Returns:
        The equivalent DictConfig
    """
    if not _HAS_HYDRA:
        logging.error("This function requires Hydra/Omegaconf and it was not installed.")
        exit(1)
    if not isinstance(cfg, (OmegaConf, DictConfig)) and is_dataclass(cfg):
        cfg = OmegaConf.structured(cfg)

    if not isinstance(cfg, DictConfig):
        raise ValueError(f"cfg constructor argument must be of type DictConfig/dict but got {type(cfg)} instead.")

    config = OmegaConf.to_container(cfg, resolve=True)
    config = OmegaConf.create(config)
    return config
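# Hedged illustration (assumption, not from the source): a dataclass config is promoted to a
# resolved DictConfig, exactly the path the function above takes via OmegaConf.structured and
# OmegaConf.to_container. `MyModelConfig` is a hypothetical plain dataclass used for illustration.
from dataclasses import dataclass
from omegaconf import OmegaConf

@dataclass
class MyModelConfig:
    hidden_size: int = 256
    dropout: float = 0.1

cfg = convert_model_config_to_dict_config(MyModelConfig())
assert cfg.hidden_size == 256  # plain DictConfig with interpolations already resolved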
def __setup_dataloader_from_config(self, cfg, shuffle_should_be: bool = True, name: str = "train"):
    if "dataset" not in cfg or not isinstance(cfg.dataset, DictConfig):
        raise ValueError(f"No dataset for {name}")
    if "dataloader_params" not in cfg or not isinstance(cfg.dataloader_params, DictConfig):
        raise ValueError(f"No dataloader_params for {name}")
    if shuffle_should_be:
        if 'shuffle' not in cfg.dataloader_params:
            logging.warning(
                f"Shuffle should be set to True for {self}'s {name} dataloader but was not found in its "
                "config. Manually setting to True"
            )
            with open_dict(cfg.dataloader_params):
                cfg.dataloader_params.shuffle = True
        elif not cfg.dataloader_params.shuffle:
            logging.error(f"The {name} dataloader for {self} has shuffle set to False!!!")
    elif cfg.dataloader_params.shuffle:
        logging.error(f"The {name} dataloader for {self} has shuffle set to True!!!")

    if cfg.dataset._target_ == "nemo.collections.tts.torch.data.TTSDataset":
        phon_mode = contextlib.nullcontext()
        if hasattr(self.vocab, "set_phone_prob"):
            phon_mode = self.vocab.set_phone_prob(prob=None if name == "val" else self.vocab.phoneme_probability)

        with phon_mode:
            dataset = instantiate(
                cfg.dataset,
                text_normalizer=self.normalizer,
                text_normalizer_call_kwargs=self.text_normalizer_call_kwargs,
                text_tokenizer=self.vocab,
            )
    else:
        dataset = instantiate(cfg.dataset)

    return torch.utils.data.DataLoader(dataset, collate_fn=dataset.collate_fn, **cfg.dataloader_params)
def __setup_dataloader_from_config(self, cfg, shuffle_should_be: bool = True, name: str = "train"):
    if "dataset" not in cfg or not isinstance(cfg.dataset, DictConfig):
        raise ValueError(f"No dataset for {name}")
    if "dataloader_params" not in cfg or not isinstance(cfg.dataloader_params, DictConfig):
        raise ValueError(f"No dataloader_params for {name}")
    if shuffle_should_be:
        if 'shuffle' not in cfg.dataloader_params:
            logging.warning(
                f"Shuffle should be set to True for {self}'s {name} dataloader but was not found in its "
                "config. Manually setting to True"
            )
            with open_dict(cfg.dataloader_params):
                cfg.dataloader_params.shuffle = True
        elif not cfg.dataloader_params.shuffle:
            logging.error(f"The {name} dataloader for {self} has shuffle set to False!!!")
    elif not shuffle_should_be and cfg.dataloader_params.shuffle:
        logging.error(f"The {name} dataloader for {self} has shuffle set to True!!!")

    # TODO(Oktai15): remove it in 1.8.0 version
    if cfg.dataset._target_ == "nemo.collections.asr.data.audio_to_text.FastPitchDataset":
        dataset = instantiate(cfg.dataset, parser=self.parser)
    elif cfg.dataset._target_ == "nemo.collections.tts.torch.data.TTSDataset":
        dataset = instantiate(
            cfg.dataset,
            text_normalizer=self.normalizer,
            text_normalizer_call_kwargs=self.text_normalizer_call_kwargs,
            text_tokenizer=self.vocab,
        )
    else:
        # TODO(Oktai15): remove it in 1.8.0 version
        dataset = instantiate(cfg.dataset)

    return torch.utils.data.DataLoader(dataset, collate_fn=dataset.collate_fn, **cfg.dataloader_params)
def __init__(
    self, train_tensors=[], wandb_name=None, wandb_project=None, args=None, update_freq=25,
):
    """
    Args:
        train_tensors: list of tensors to evaluate and log based on training batches
        wandb_name: wandb experiment name
        wandb_project: wandb project name
        args: argparse flags - will be logged as hyperparameters
        update_freq: frequency with which to log updates
    """
    super().__init__()

    if not _WANDB_AVAILABLE:
        logging.error("Could not import wandb. Did you install it (pip install --upgrade wandb)?")

    self._update_freq = update_freq
    self._train_tensors = train_tensors
    self._name = wandb_name
    self._project = wandb_project
    self._args = args
def restore_weights(self, restore_path: str):
    """Restores module/model's weights.
       For model parallel checkpoints the directory structure
       should be restore_path/mp_rank_0X/model_optim_rng.pt

    Args:
        restore_path (str): restore_path should be a file or a directory if using model parallel
    """
    self._restore_path = restore_path

    if os.path.isfile(restore_path):
        logging.info(f'restore_path: {restore_path} is a file. Assuming no megatron model parallelism')
        state_dict = torch.load(restore_path)
        # to load from Megatron pretrained checkpoint
        if 'model' in state_dict:
            self.language_model.load_state_dict(state_dict['model'][self._language_model_key])
        else:
            self.load_state_dict(state_dict)
        logging.info(f"weights restored from {restore_path}")
    elif os.path.isdir(restore_path):
        # need model parallel groups to restore model parallel checkpoints
        if model_parallel_is_initialized():
            model_parallel_rank = torch.distributed.get_rank(group=get_model_parallel_group())
            mp_restore_path = f'{restore_path}/mp_rank_{model_parallel_rank:02d}/model_optim_rng.pt'
            logging.info(f'Restoring model parallel checkpoint from: {mp_restore_path}')
            state_dict = torch.load(mp_restore_path)
            # to load from Megatron pretrained checkpoint
            if 'model' in state_dict:
                self.language_model.load_state_dict(state_dict['model'][self._language_model_key])
            else:
                self.load_state_dict(state_dict)
        else:
            logging.info('torch.distributed not initialized yet. Will not restore model parallel checkpoint')
    else:
        logging.error(f'restore_path: {restore_path} must be a file or directory.')
def restore_weights(self, restore_path: str):
    """Restores module/model's weights.
       For model parallel checkpoints the directory structure
       should be restore_path/mp_rank_0X/model_optim_rng.pt

    Args:
        restore_path (str): restore_path should be a file or a directory if using model parallel
    """
    self._restore_path = restore_path

    if os.path.isfile(restore_path):
        self._load_checkpoint(restore_path)
    elif os.path.isdir(restore_path):
        # need model parallel groups to restore model parallel checkpoints
        if model_parallel_is_initialized():
            model_parallel_rank = torch.distributed.get_rank(group=get_model_parallel_group())
            mp_restore_path = f'{restore_path}/mp_rank_{model_parallel_rank:02d}/model_optim_rng.pt'
            self._load_checkpoint(mp_restore_path)
        else:
            logging.info('torch.distributed not initialized yet. Will not restore model parallel checkpoint')
    else:
        logging.error(f'restore_path: {restore_path} must be a file or directory.')
def configure_checkpointing(
    trainer: 'pytorch_lightning.Trainer', log_dir: Path, name: str, params: Dict,
):
    """ Adds ModelCheckpoint to trainer. Raises CheckpointMisconfigurationError if trainer already has a
    ModelCheckpoint callback or if trainer.weights_save_path was passed to Trainer.
    """
    for callback in trainer.callbacks:
        if isinstance(callback, ModelCheckpoint):
            raise CheckpointMisconfigurationError(
                "The pytorch lightning trainer that was passed to exp_manager contained a ModelCheckpoint "
                "and create_checkpoint_callback was set to True. Please either set create_checkpoint_callback "
                "to False, or remove ModelCheckpoint from the lightning trainer"
            )
    if Path(trainer.weights_save_path) != Path.cwd():
        raise CheckpointMisconfigurationError(
            "The pytorch lightning trainer was passed weights_save_path. This variable is ignored by exp_manager"
        )

    # Create the callback and attach it to trainer
    if params.filepath is None:
        params.filepath = Path(log_dir / 'checkpoints' / f'--{{{params.monitor}:.2f}}-{{epoch}}')
    if params.prefix is None:
        params.prefix = name

    if "val" in params.monitor and trainer.max_epochs != -1 and trainer.max_epochs < trainer.check_val_every_n_epoch:
        logging.error(
            "The checkpoint callback was told to monitor a validation value but trainer.max_epochs("
            f"{trainer.max_epochs}) was less than trainer.check_val_every_n_epoch({trainer.check_val_every_n_epoch}). "
            f"It is very likely this run will fail with ModelCheckpoint(monitor='{params.monitor}') not found in the "
            "returned metrics. Please ensure that validation is run within trainer.max_epochs."
        )

    checkpoint_callback = NeMoModelCheckpoint(**params)
    trainer.callbacks.append(checkpoint_callback)
def __setup_dataloader_from_config(self, cfg, shuffle_should_be: bool = True, name: str = "train"):
    if "dataset" not in cfg or not isinstance(cfg.dataset, DictConfig):
        raise ValueError(f"No dataset for {name}")
    if "dataloader_params" not in cfg or not isinstance(cfg.dataloader_params, DictConfig):
        raise ValueError(f"No dataloader_params for {name}")
    if shuffle_should_be:
        if 'shuffle' not in cfg.dataloader_params:
            logging.warning(
                f"Shuffle should be set to True for {self}'s {name} dataloader but was not found in its "
                "config. Manually setting to True"
            )
            with open_dict(cfg.dataloader_params):
                cfg.dataloader_params.shuffle = True
        elif not cfg.dataloader_params.shuffle:
            logging.error(f"The {name} dataloader for {self} has shuffle set to False!!!")
    elif not shuffle_should_be and cfg.dataloader_params.shuffle:
        logging.error(f"The {name} dataloader for {self} has shuffle set to True!!!")

    kwargs_dict = {}
    if cfg.dataset._target_ == "nemo.collections.asr.data.audio_to_text.FastPitchDataset":
        kwargs_dict["parser"] = self.parser
    dataset = instantiate(cfg.dataset, **kwargs_dict)

    return torch.utils.data.DataLoader(dataset, collate_fn=dataset.collate_fn, **cfg.dataloader_params)
def _convert_config(cfg: 'OmegaConf'):
    """ Recursive function converting the configuration from the old hydra format to the new one. """
    if not _HAS_HYDRA:
        logging.error("This function requires Hydra/Omegaconf and it was not installed.")
        exit(1)

    # Get rid of cls -> _target_.
    if 'cls' in cfg and '_target_' not in cfg:
        cfg._target_ = cfg.pop('cls')

    # Get rid of params.
    if 'params' in cfg:
        params = cfg.pop('params')
        for param_key, param_val in params.items():
            cfg[param_key] = param_val

    # Recursion.
    try:
        for _, sub_cfg in cfg.items():
            if isinstance(sub_cfg, DictConfig):
                _convert_config(sub_cfg)
    except omegaconf_errors.OmegaConfBaseException as e:
        logging.warning(f"Skipped conversion for config/subconfig:\n{cfg}\n Reason: {e}.")
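# Hedged before/after illustration (assumption, not from the source) of the old-style hydra
# layout that _convert_config flattens: `cls` becomes `_target_` and the contents of `params`
# are inlined into the same node. The target path and values are only for illustration.
from omegaconf import OmegaConf

old_cfg = OmegaConf.create(
    {"preprocessor": {"cls": "nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor",
                      "params": {"sample_rate": 16000}}}
)
_convert_config(old_cfg)
# old_cfg.preprocessor now looks like:
# {"_target_": "nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor", "sample_rate": 16000}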
def error_checks(trainer: 'pytorch_lightning.Trainer', cfg: Optional[Union[DictConfig, Dict]] = None):
    """
    Checks that the passed trainer is compliant with NeMo and exp_manager's passed configuration. Checks that:
        - Throws error when hydra has changed the working directory. This causes issues with lightning's DDP
        - Throws error when trainer has loggers defined but create_tensorboard_logger or create_wandb_logger is True
        - Prints error messages when 1) run on multi-node and not Slurm, and 2) run on multi-gpu without DDP
    """
    if HydraConfig.initialized() and get_original_cwd() != os.getcwd():
        raise ValueError(
            "Hydra changed the working directory. This interferes with ExpManager's functionality. Please pass "
            "hydra.run.dir=. to your python script."
        )
    if trainer.logger is not None and (cfg.create_tensorboard_logger or cfg.create_wandb_logger):
        raise LoggerMisconfigurationError(
            "The pytorch lightning trainer that was passed to exp_manager contained a logger, and either "
            f"create_tensorboard_logger: {cfg.create_tensorboard_logger} or create_wandb_logger: "
            f"{cfg.create_wandb_logger} was set to True. These can only be used if trainer does not already have a"
            " logger."
        )
    if trainer.num_nodes > 1 and not check_slurm(trainer):
        logging.error(
            "You are running multi-node training without SLURM handling the processes."
            " Please note that this is not tested in NeMo and could result in errors."
        )
    if trainer.num_gpus > 1 and not isinstance(trainer.accelerator.training_type_plugin, DDPPlugin):
        logging.error(
            "You are running multi-gpu without DDP. Please note that this is not tested in NeMo and could result in "
            "errors."
        )
def get_translation():
    try:
        time_s = time.time()
        langpair = request.args["langpair"]
        src = request.args["text"]
        do_moses = request.args.get('do_moses', False)
        if langpair in MODELS_DICT:
            if do_moses:
                result = MODELS_DICT[langpair].translate(
                    [src], source_lang=langpair.split('-')[0], target_lang=langpair.split('-')[1]
                )
            else:
                result = MODELS_DICT[langpair].translate([src])
            duration = time.time() - time_s
            logging.info(
                f"Translated in {duration}. Input was: {request.args['text']} <############> Translation was: {result[0]}"
            )
            res = {'translation': result[0]}
            response = flask.jsonify(res)
            response.headers.add('Access-Control-Allow-Origin', '*')
            return response
        else:
            logging.error(f"Got the following langpair: {langpair} which was not found")
    except Exception as ex:
        res = {'translation': str(ex)}
        response = flask.jsonify(res)
        response.headers.add('Access-Control-Allow-Origin', '*')
        # return the jsonified response so the CORS header added above is preserved
        return response
def check_explicit_log_dir(
    trainer: 'pytorch_lightning.Trainer', explicit_log_dir: [Path, str], exp_dir: str, name: str, version: str
) -> (Path, str, str, str):
    """ Checks that the passed arguments are compatible with explicit_log_dir.

    Returns:
        log_dir (Path): the log_dir
        exp_dir (str): the base exp_dir without name nor version
        name (str): The name of the experiment
        version (str): The version of the experiment

    Raises:
        LoggerMisconfigurationError
    """
    if trainer.logger is not None:
        raise LoggerMisconfigurationError(
            "The pytorch lightning trainer that was passed to exp_manager contained a logger and explicit_log_dir: "
            f"{explicit_log_dir} was passed to exp_manager. Please remove the logger from the lightning trainer."
        )
    # Checking only (explicit_log_dir) vs (exp_dir and version).
    # The `name` will be used as the actual name of checkpoint/archive.
    if exp_dir or version:
        logging.error(
            f"exp_manager received explicit_log_dir: {explicit_log_dir} and at least one of exp_dir: {exp_dir}, "
            f"or version: {version}. Please note that exp_dir, name, and version will be ignored."
        )
    if is_global_rank_zero() and Path(explicit_log_dir).exists():
        logging.warning(f"Exp_manager is logging to {explicit_log_dir}, but it already exists.")
    return Path(explicit_log_dir), str(explicit_log_dir), "", ""
def _try_jit_compile_model(self, module, try_script):
    jitted_model = None
    if try_script:
        try:
            jitted_model = torch.jit.script(module)
        except Exception as e:
            logging.error(f"jit.script() failed!\n{e}")
    return jitted_model
def from_file(
    cls, audio_file, target_sr=None, int_values=False, offset=0, duration=0, trim=False, orig_sr=None,
):
    """
    Load a file supported by librosa and return as an AudioSegment.
    :param audio_file: path of file to load
    :param target_sr: the desired sample rate
    :param int_values: if true, load samples as 32-bit integers
    :param offset: offset in seconds when loading audio
    :param duration: duration in seconds when loading audio
    :return: numpy array of samples
    """
    samples = None
    if not isinstance(audio_file, str) or os.path.splitext(audio_file)[-1] in sf_supported_formats:
        try:
            with sf.SoundFile(audio_file, 'r') as f:
                dtype = 'int32' if int_values else 'float32'
                sample_rate = f.samplerate
                if offset > 0:
                    f.seek(int(offset * sample_rate))
                if duration > 0:
                    samples = f.read(int(duration * sample_rate), dtype=dtype)
                else:
                    samples = f.read(dtype=dtype)
            samples = samples.transpose()
        except RuntimeError as e:
            logging.error(
                f"Loading {audio_file} via SoundFile raised RuntimeError: `{e}`. "
                f"NeMo will fallback to loading via pydub."
            )
    elif isinstance(audio_file, str) and audio_file.strip()[-1] == "|":
        f = open_like_kaldi(audio_file, "rb")
        sample_rate, samples = read_kaldi(f)
        if offset > 0:
            samples = samples[int(offset * sample_rate) :]
        if duration > 0:
            samples = samples[: int(duration * sample_rate)]
        if not int_values:
            abs_max_value = np.abs(samples).max()
            # np.float64 is used in place of the removed np.float alias
            samples = np.array(samples, dtype=np.float64) / abs_max_value

    if samples is None:
        try:
            samples = Audio.from_file(audio_file)
            sample_rate = samples.frame_rate
            if offset > 0:
                # pydub does things in milliseconds
                seconds = offset * 1000
                samples = samples[int(seconds) :]
            if duration > 0:
                seconds = duration * 1000
                samples = samples[: int(seconds)]
            samples = np.array(samples.get_array_of_samples())
        except CouldntDecodeError as e:
            logging.error(f"Loading {audio_file} via pydub raised CouldntDecodeError: `{e}`.")

    return cls(samples, sample_rate, target_sr=target_sr, trim=trim, orig_sr=orig_sr)
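# Hedged usage sketch (assumption, not taken from the source): loading a clip with the
# classmethod above. The class name `AudioSegment`, the file path, and the samples /
# sample_rate attributes are assumptions; only the parameters come from the docstring.
segment = AudioSegment.from_file(
    "speech.wav",      # hypothetical path; any soundfile/librosa-readable file
    target_sr=16000,   # resample to 16 kHz on load
    offset=1.0,        # skip the first second
    duration=2.5,      # keep 2.5 seconds
    trim=False,
)
print(segment.sample_rate, segment.samples.shape)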
def _parse_as_cmu_dict(phoneme_dict_path=None, encoding='latin-1'):
    if phoneme_dict_path is None:
        # this part of code downloads file, but it is not rank zero guarded
        # Try to check if torch distributed is available, if not get global rank zero to download corpora and make
        # all other ranks sleep for a minute
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            group = torch.distributed.group.WORLD
            if is_global_rank_zero():
                try:
                    nltk.data.find('corpora/cmudict.zip')
                except LookupError:
                    nltk.download('cmudict', quiet=True)
            torch.distributed.barrier(group=group)
        elif is_global_rank_zero():
            logging.error(
                "Torch distributed needs to be initialized before you initialized EnglishG2p. This class is prone to "
                "data access race conditions. Now downloading corpora from global rank 0. If other ranks pass this "
                "before rank 0, errors might result."
            )
            try:
                nltk.data.find('corpora/cmudict.zip')
            except LookupError:
                nltk.download('cmudict', quiet=True)
        else:
            logging.error(
                "Torch distributed needs to be initialized before you initialized EnglishG2p. This class is prone to "
                "data access race conditions. This process is not rank 0, and now going to sleep for 1 min. If this "
                "rank wakes from sleep prior to rank 0 finishing downloading, errors might result."
            )
            time.sleep(60)

        logging.warning(
            "English g2p_dict will be used from nltk.corpus.cmudict.dict(), because phoneme_dict_path=None. "
            "Note that nltk.corpus.cmudict.dict() has old version (0.6) of CMUDict. "
            "You can use the latest official version of CMUDict (0.7b) with additional changes from NVIDIA directly from NeMo "
            "using the path scripts/tts_dataset_files/cmudict-0.7b_nv22.01."
        )

        return nltk.corpus.cmudict.dict()

    _alt_re = re.compile(r'\([0-9]+\)')
    g2p_dict = {}
    with open(phoneme_dict_path, encoding=encoding) as file:
        for line in file:
            if len(line) and ('A' <= line[0] <= 'Z' or line[0] == "'"):
                # the dictionary file separates the word from its phonemes with two spaces
                parts = line.split('  ')
                word = re.sub(_alt_re, '', parts[0])
                word = word.lower()

                pronunciation = parts[1].strip().split(" ")
                if word in g2p_dict:
                    g2p_dict[word].append(pronunciation)
                else:
                    g2p_dict[word] = [pronunciation]
    return g2p_dict
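# Hedged illustration (not from the source): the on-disk CMUdict entries that the parser above
# turns into a word -> list-of-pronunciations dict. In the official cmudict-0.7b files the word
# and its phonemes are separated by two spaces, which is why the line is split on '  '.
# Alternate pronunciations carry a "(N)" suffix that _alt_re strips, so they merge under one key.
example_lines = [
    "READ  R EH1 D",
    "READ(1)  R IY1 D",
]
# After parsing lines like these:
# g2p_dict["read"] == [['R', 'EH1', 'D'], ['R', 'IY1', 'D']]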
def __restore_from(self, path, state):
    if not os.path.isdir(path):
        if self._force_load:
            raise ValueError("force_load was set to True for checkpoint callback but a checkpoint was not found.")
        logging.warning(f"Checkpoint folder {path} not found!")
    else:
        logging.info(f"Found checkpoint folder {path}. Will attempt to restore checkpoints from it.")

    modules_to_restore = []
    modules_to_restore_name = []
    for module in AppState().modules:
        if module.num_weights > 0:
            modules_to_restore.append(module)
            modules_to_restore_name.append(str(module))

    step_check = None
    try:
        module_checkpoints, steps = get_checkpoint_from_dir(modules_to_restore_name, path, return_steps=True)

        # If the steps are different, print a warning message
        for step in steps:
            if step_check is None:
                step_check = step
            elif step != step_check:
                logging.warning("Restoring from module checkpoints where the training step does not match")
                break

        for mod, checkpoint in zip(modules_to_restore, module_checkpoints):
            mod.restore_from(checkpoint, state["local_rank"])
    except ValueError as e:
        if self._force_load:
            raise ValueError("force_load was set to True for checkpoint callback but a checkpoint was not found.")
        logging.warning(e)
        logging.warning(
            f"Checkpoint folder {path} was present but nothing was restored. Continuing training from random "
            "initialization."
        )
        return

    try:
        trainer_checkpoints, steps = get_checkpoint_from_dir(["trainer"], path, return_steps=True)
        if step_check is not None and step_check != steps[0]:
            logging.error(
                "The step we are restoring from the trainer checkpoint does not match one or more steps that "
                "are being restored from modules."
            )
        state.restore_state_from(trainer_checkpoints[0])
    except ValueError as e:
        logging.warning(e)
        logging.warning(
            "Trainer state such as optimizer state and current step/epoch was not restored. Pretrained weights"
            " have still been restored and fine-tuning should continue fine."
        )
        return
def compile_helper():
    """Compile helper function at runtime. Make sure this is invoked on a single process."""
    path = os.path.abspath(os.path.dirname(__file__))
    ret = subprocess.run(['make', '-C', path])
    if ret.returncode != 0:
        logging.error("Making C++ dataset helpers module failed, exiting.")
        import sys

        sys.exit(1)
def on_action_start(self):
    if self.global_rank is None or self.global_rank == 0:
        if self._wandb_name is not None or self._wandb_project is not None:
            if _WANDB_AVAILABLE and wandb.run is None:
                wandb.init(name=self._wandb_name, project=self._wandb_project)
            elif _WANDB_AVAILABLE and wandb.run is not None:
                logging.info("Re-using wandb session")
            else:
                logging.error("Could not import wandb. Did you install it (pip install --upgrade wandb)?")
                logging.info("Will not log data to weights and biases.")
                self._wandb_name = None
                self._wandb_project = None
def on_train_start(self, state):
    if state["global_rank"] is None or state["global_rank"] == 0:
        if _WANDB_AVAILABLE and wandb.run is None:
            wandb.init(name=self._name, project=self._project)
            if self._args is not None:
                wandb.config.update(self._args)
        elif _WANDB_AVAILABLE and wandb.run is not None:
            logging.info("Re-using wandb session")
        else:
            logging.error("Could not import wandb. Did you install it (pip install --upgrade wandb)?")
            logging.info("Will not log data to weights and biases.")
            self._step_freq = -1
def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
    if isinstance(cfg, dict):
        cfg = OmegaConf.create(cfg)
    super().__init__(cfg=cfg, trainer=trainer)

    self.audio_to_melspec_precessor = instantiate(cfg.preprocessor)
    self.encoder = instantiate(cfg.encoder)
    self.variance_adapter = instantiate(cfg.variance_adaptor)

    self.generator = instantiate(cfg.generator)
    self.multiperioddisc = MultiPeriodDiscriminator()
    self.multiscaledisc = MultiScaleDiscriminator()

    self.melspec_fn = instantiate(cfg.preprocessor, highfreq=None, use_grads=True)
    self.mel_val_loss = L1MelLoss()
    self.durationloss = DurationLoss()
    self.feat_matching_loss = FeatureMatchingLoss()
    self.disc_loss = DiscriminatorLoss()
    self.gen_loss = GeneratorLoss()
    self.mseloss = torch.nn.MSELoss()

    self.energy = cfg.add_energy_predictor
    self.pitch = cfg.add_pitch_predictor
    self.mel_loss_coeff = cfg.mel_loss_coeff
    self.pitch_loss_coeff = cfg.pitch_loss_coeff
    self.energy_loss_coeff = cfg.energy_loss_coeff
    self.splice_length = cfg.splice_length

    self.use_energy_pred = False
    self.use_pitch_pred = False
    self.log_train_images = False
    self.logged_real_samples = False
    self._tb_logger = None
    self.sample_rate = cfg.sample_rate
    self.hop_size = cfg.hop_size

    # Parser and mappings are used for inference only.
    self.parser = parsers.make_parser(name='en')
    if 'mappings_filepath' in cfg:
        mappings_filepath = cfg.get('mappings_filepath')
    else:
        logging.error(
            "ERROR: You must specify a mappings.json file in the config file under model.mappings_filepath."
        )
    mappings_filepath = self.register_artifact('mappings_filepath', mappings_filepath)
    with open(mappings_filepath, 'r') as f:
        mappings = json.load(f)
        self.word2phones = mappings['word2phones']
        self.phone2idx = mappings['phone2idx']
def _add_subconfig_keys(model_cfg: 'DictConfig', update_cfg: 'DictConfig', subconfig_key: str):
    """
    For certain sub-configs, the default values specified by the NemoConfig class are insufficient.
    In order to support every potential value in the merge with `update_cfg`, it would require
    explicit definition of all possible cases.

    An example of such a case is Optimizers, and their equivalent Schedulers. All optimizers share a few basic
    details - such as name and lr, but almost all require additional parameters - such as weight decay.

    It is impractical to create a config for every single optimizer + every single scheduler combination.

    In such a case, we perform a dual merge. The Optim and Sched Dataclasses contain the bare minimum essential
    components. The extra values are provided via update_cfg.

    In order to enable the merge, we first need to update the update sub-config to incorporate the keys,
    with dummy temporary values (merge update config with model config). This is done on a copy of the
    update sub-config, as the actual override values might be overridden by the NemoConfig defaults.

    Then we perform a merge of this temporary sub-config with the actual override config in a later step
    (merge model_cfg with original update_cfg, done outside this function).

    Args:
        model_cfg: A DictConfig instantiated from the NemoConfig subclass.
        update_cfg: A DictConfig that mirrors the structure of `model_cfg`, used to update its default values.
        subconfig_key: A str key used to check and update the sub-config.

    Returns:
        A ModelPT DictConfig with additional keys added to the sub-config.
    """
    if not _HAS_HYDRA:
        logging.error("This function requires Hydra/Omegaconf and it was not installed.")
        exit(1)
    with open_dict(model_cfg.model):
        # Create copy of original model sub config
        if subconfig_key in update_cfg.model:
            if subconfig_key not in model_cfg.model:
                # create the key as a placeholder
                model_cfg.model[subconfig_key] = None

            subconfig = copy.deepcopy(model_cfg.model[subconfig_key])
            update_subconfig = copy.deepcopy(update_cfg.model[subconfig_key])

            # Add the keys and update temporary values, will be updated during full merge
            subconfig = OmegaConf.merge(update_subconfig, subconfig)
            # Update sub config
            model_cfg.model[subconfig_key] = subconfig

    return model_cfg
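# Hedged sketch (assumption, not from the source) of the dual-merge idea described in the
# docstring above: the override's extra optimizer keys (e.g. weight_decay) are first grafted
# onto the default sub-config so a later full merge cannot drop or reject them. Keys and
# values here are hypothetical.
from omegaconf import OmegaConf

model_cfg = OmegaConf.create({"model": {"optim": {"name": "adam", "lr": 1e-3}}})
update_cfg = OmegaConf.create({"model": {"optim": {"name": "adamw", "lr": 5e-4, "weight_decay": 0.01}}})

model_cfg = _add_subconfig_keys(model_cfg, update_cfg, subconfig_key="optim")
# model_cfg.model.optim now also contains a weight_decay entry, so a subsequent
# OmegaConf.merge(model_cfg, update_cfg) can apply the user's adamw / 5e-4 / 0.01
# instead of tripping over a key that the defaults never defined.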
def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
    if isinstance(cfg, dict):
        cfg = OmegaConf.create(cfg)
    super().__init__(cfg=cfg, trainer=trainer)

    schema = OmegaConf.structured(FastSpeech2Config)
    # ModelPT ensures that cfg is a DictConfig, but do this second check in case ModelPT changes
    if isinstance(cfg, dict):
        cfg = OmegaConf.create(cfg)
    elif not isinstance(cfg, DictConfig):
        raise ValueError(f"cfg was type: {type(cfg)}. Expected either a dict or a DictConfig")
    # Ensure passed cfg is compliant with schema
    OmegaConf.merge(cfg, schema)

    self.pitch = cfg.add_pitch_predictor
    self.energy = cfg.add_energy_predictor
    self.duration_coeff = cfg.duration_coeff

    self.audio_to_melspec_preprocessor = instantiate(self._cfg.preprocessor)
    self.encoder = instantiate(self._cfg.encoder)
    self.mel_decoder = instantiate(self._cfg.decoder)
    self.variance_adapter = instantiate(self._cfg.variance_adaptor)

    self.loss = L2MelLoss()
    self.mseloss = torch.nn.MSELoss()
    self.durationloss = DurationLoss()

    self.log_train_images = False

    # Parser and mappings are used for inference only.
    self.parser = parsers.make_parser(name='en')
    if 'mappings_filepath' in cfg:
        mappings_filepath = cfg.get('mappings_filepath')
    else:
        logging.error(
            "ERROR: You must specify a mappings.json file in the config file under model.mappings_filepath."
        )
    mappings_filepath = self.register_artifact('mappings_filepath', mappings_filepath)
    with open(mappings_filepath, 'r') as f:
        mappings = json.load(f)
        self.word2phones = mappings['word2phones']
        self.phone2idx = mappings['phone2idx']
def __init__(
    self,
    step_freq: int = 100,
    tensors_to_log: List[Union[str, NmTensor]] = ["loss"],
    wandb_name: str = None,
    wandb_project: str = None,
    args=None,
    log_epoch: bool = True,
    log_lr: bool = True,
):
    if not _WANDB_AVAILABLE:
        logging.error("Could not import wandb. Did you install it (pip install --upgrade wandb)?")
    self._step_freq = step_freq
    self._tensors_to_log = tensors_to_log
    self._name = wandb_name
    self._project = wandb_project
    self._args = args
    self._last_epoch_start = None
    self._log_epoch = log_epoch
    self._log_lr = log_lr
def forward(self, system_acts):
    """
    Generates system response

    Args:
        system_acts (list): system actions in the format:
            [['Inform', 'Train', 'Day', 'wednesday'], []]
            i.e. [act, domain, slot, slot_value]
    Returns:
        system_uttr (str): generated system utterance
    """
    action = {}
    for intent, domain, slot, value in system_acts:
        k = '-'.join([domain, intent])
        action.setdefault(k, [])
        action[k].append([slot, value])
    dialog_acts = action

    mode = self.mode
    try:
        if mode == 'manual':
            system_uttr = self._manual_generate(dialog_acts, self.manual_system_template)
        elif mode == 'auto':
            system_uttr = self._auto_generate(dialog_acts, self.auto_system_template)
        elif mode == 'auto_manual':
            template1 = self.auto_system_template
            template2 = self.manual_system_template
            system_uttr = self._auto_generate(dialog_acts, template1)
            if system_uttr == 'None':
                system_uttr = self._manual_generate(dialog_acts, template2)
        else:
            raise Exception("Invalid mode! Available modes: auto, manual, auto_manual")

        # truncate a system utterance with multiple questions
        system_uttr = self.truncate_sys_response(system_uttr)
        logging.info("NLG output = System reply: %s", system_uttr)
        return system_uttr
    except Exception as e:
        logging.error('Error in processing: %s', dialog_acts)
        raise e
def on_action_start(self, state):
    if state["global_rank"] is None or state["global_rank"] == 0:
        if wandb.run is None:
            wandb.init(
                job_type='train',
                id=self._runid,
                tags=['train', 'nemo'],
                group='train',
                name=self._name,
                project='asr',
                entity='cprc',
            )
            if self._args is not None:
                logging.info('init wandb session and append args')
                wandb.config.update(self._args)
        elif wandb.run is not None:
            logging.info("Re-using wandb session")
        else:
            logging.error("Could not import wandb. Did you install it (pip install --upgrade wandb)?")
            logging.info("Will not log data to weights and biases.")
            self._update_freq = -1