Example #1
    def _default_save_to(self, save_path: str):
        """
        Saves the model instance (weights and configuration) into a .nemo file.
        You can use the "restore_from" method to fully restore the instance from the .nemo file.

        A .nemo file is a tar.gz archive containing:
            model_config.yaml - model configuration in .yaml format. You can deserialize this into the cfg argument for the model's constructor
            model_weights.ckpt - model checkpoint

        Args:
            save_path: Path to .nemo file where model instance should be saved
        """
        with tempfile.TemporaryDirectory() as tmpdir:
            config_yaml = path.join(tmpdir, _MODEL_CONFIG_YAML)
            model_weights = path.join(tmpdir, _MODEL_WEIGHTS)
            if hasattr(self, 'artifacts') and self.artifacts is not None:
                for (conf_path, src) in self.artifacts:
                    try:
                        if os.path.exists(src):
                            shutil.copy2(src, tmpdir)
                    except Exception:
                        logging.error(
                            f"Could not copy artifact {src} used in {conf_path}"
                        )

            self.to_config_file(path2yaml_file=config_yaml)
            torch.save(self.state_dict(), model_weights)
            self.__make_nemo_file_from_folder(filename=save_path,
                                              source_dir=tmpdir)
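For context, here is a minimal standalone sketch of what the final packing step plausibly does. The real __make_nemo_file_from_folder is private to NeMo; the tarfile usage below is an assumption based on the docstring's "archive (tar.gz)" description, not NeMo's actual implementation.

import tarfile

def make_nemo_file_from_folder(filename: str, source_dir: str):
    # Hypothetical re-implementation: a .nemo file is a gzipped tar of the
    # staging folder (model_config.yaml + model_weights.ckpt + artifacts).
    with tarfile.open(filename, "w:gz") as tar:
        tar.add(source_dir, arcname=".")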
Example #2
    def __setup_dataloader_from_config(self,
                                       cfg,
                                       shuffle_should_be: bool = True,
                                       name: str = "train"):
        if "dataset" not in cfg or not isinstance(cfg.dataset, DictConfig):
            raise ValueError(f"No dataset for {name}")  # TODO
        if "dataloader_params" not in cfg or not isinstance(
                cfg.dataloader_params, DictConfig):
            raise ValueError(f"No dataloder_params for {name}")  # TODO
        if shuffle_should_be:
            if 'shuffle' not in cfg.dataloader_params:
                logging.warning(
                    f"Shuffle should be set to True for {self}'s {name} dataloader but was not found in its "
                    "config. Manually setting to True")
                with open_dict(cfg["dataloader_params"]):
                    cfg.dataloader_params.shuffle = True
            elif not cfg.dataloader_params.shuffle:
                logging.error(
                    f"The {name} dataloader for {self} has shuffle set to False!!!"
                )
        elif cfg.dataloader_params.shuffle:
            logging.error(
                f"The {name} dataloader for {self} has shuffle set to True!!!")

        dataset = instantiate(cfg.dataset)
        return torch.utils.data.DataLoader(dataset,
                                           collate_fn=dataset.collate_fn,
                                           **cfg.dataloader_params)
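A minimal config that passes the checks above might look like the following sketch; my_pkg.data.MyDataset is a hypothetical dataset class exposing a collate_fn.

from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "dataset": {
        "_target_": "my_pkg.data.MyDataset",  # hypothetical dataset class
        "manifest_filepath": "train_manifest.json",
    },
    "dataloader_params": {
        "batch_size": 32,
        "shuffle": True,  # expected (or auto-set) for the train dataloader
        "num_workers": 4,
    },
})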
Example #3
    @classmethod
    def segment_from_file(cls,
                          audio_file,
                          target_sr=None,
                          n_segments=0,
                          trim=False,
                          orig_sr=None):
        """Grabs n_segments number of samples from audio_file randomly from the
        file as opposed to at a specified offset.

        Note that audio_file can be either the file path, or a file-like object.
        """
        try:
            with sf.SoundFile(audio_file, 'r') as f:
                sample_rate = f.samplerate
                if n_segments > 0 and len(f) > n_segments:
                    max_audio_start = len(f) - n_segments
                    audio_start = random.randint(0, max_audio_start)
                    f.seek(audio_start)
                    samples = f.read(n_segments, dtype='float32')
                else:
                    samples = f.read(dtype='float32')
            samples = samples.transpose()
        except RuntimeError as e:
            logging.error(
                f"Loading {audio_file} via SoundFile raised RuntimeError: `{e}`."
            )
            raise

        return cls(samples,
                   sample_rate,
                   target_sr=target_sr,
                   trim=trim,
                   orig_sr=orig_sr)
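The random-window logic can be illustrated standalone. This is a sketch using only the soundfile API; note that seek and read operate in frames, not seconds.

import random
import soundfile as sf

def read_random_segment(path, n_segments):
    with sf.SoundFile(path, 'r') as f:
        if n_segments > 0 and len(f) > n_segments:
            start = random.randint(0, len(f) - n_segments)  # frame offset
            f.seek(start)
            return f.read(n_segments, dtype='float32')
        return f.read(dtype='float32')  # file shorter than requested segment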
Example #4
def configure_checkpointing(trainer: 'pytorch_lightning.Trainer',
                            log_dir: Path, name: str, params: 'DictConfig'):
    """ Adds ModelCheckpoint to trainer. Raises CheckpointMisconfigurationError if trainer already has a ModelCheckpoint
    callback or if trainer.weights_save_path was passed to Trainer.
    """
    for callback in trainer.callbacks:
        if isinstance(callback, ModelCheckpoint):
            raise CheckpointMisconfigurationError(
                "The pytorch lightning trainer that was passed to exp_manager contained a ModelCheckpoint "
                "and create_checkpoint_callback was set to True. Please either set create_checkpoint_callback "
                "to False, or remove ModelCheckpoint from the lightning trainer"
            )
    if Path(trainer.weights_save_path) != Path.cwd():
        raise CheckpointMisconfigurationError(
            "The pytorch lightning was passed weights_save_path. This variable is ignored by exp_manager"
        )

    # Create the callback and attach it to trainer
    if "filepath" in params:
        if params.filepath is not None:
            logging.warning(
                "filepath is deprecated. Please switch to dirpath and filename instead"
            )
            if params.dirpath is None:
                params.dirpath = Path(params.filepath).parent
            if params.filename is None:
                params.filename = Path(params.filepath).name
        with open_dict(params):
            del params["filepath"]
    if params.dirpath is None:
        params.dirpath = Path(log_dir / 'checkpoints')
    if params.filename is None:
        params.filename = f'{name}--{{{params.monitor}:.2f}}-{{epoch}}'
    if params.prefix is None:
        params.prefix = name
    NeMoModelCheckpoint.CHECKPOINT_NAME_LAST = params.filename + '-last'

    logging.debug(params.dirpath)
    logging.debug(params.filename)
    logging.debug(params.prefix)

    if "val" in params.monitor:
        if (trainer.max_epochs is not None and trainer.max_epochs != -1
                and trainer.max_epochs < trainer.check_val_every_n_epoch):
            logging.error(
                "The checkpoint callback was told to monitor a validation value but trainer.max_epochs("
                f"{trainer.max_epochs}) was less than trainer.check_val_every_n_epoch({trainer.check_val_every_n_epoch}"
                f"). It is very likely this run will fail with ModelCheckpoint(monitor='{params.monitor}') not found "
                "in the returned metrics. Please ensure that validation is run within trainer.max_epochs."
            )
        elif trainer.max_steps is not None:
            logging.warning(
                "The checkpoint callback was told to monitor a validation value and trainer's max_steps was set to "
                f"{trainer.max_steps}. Please ensure that max_steps will run for at least "
                f"{trainer.check_val_every_n_epoch} epochs to ensure that checkpointing will not error out."
            )

    checkpoint_callback = NeMoModelCheckpoint(**params)
    checkpoint_callback.last_model_path = trainer.resume_from_checkpoint or ""
    trainer.callbacks.append(checkpoint_callback)
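The triple braces in the filename template are easy to misread: the inner {params.monitor} is interpolated immediately, while the {{...}} escapes survive as literal fields for ModelCheckpoint to fill later. A quick illustration, with made-up values:

name, monitor = 'my_model', 'val_loss'
template = f'{name}--{{{monitor}:.2f}}-{{epoch}}'
print(template)  # my_model--{val_loss:.2f}-{epoch}
# ModelCheckpoint later formats the remaining fields, e.g. my_model--0.12-epoch=3.ckpt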
Example #5
    def sync_all_processes(self, status=True):
        """ Helper function for testing that allows proccess 0 to inform all
        other processes of failures. Does nothing if not using distributed
        training. Usage example can be seen in examples/asr/jasper_an4.py

        Args:
            status (bool): Defaults to True. If any process passes False, it
                will trigger a graceful exit on all other processes. It is
                assumed that the process that passed False will print an error
                message on its own and exit
        """
        if self._world_size == 1:
            logging.info(
                "sync_all_processes does nothing if there is one process")
            return
        if True:  # self._backend == Backend.PyTorch:
            import torch

            status_tensor = torch.cuda.IntTensor([status])
            torch.distributed.all_reduce(status_tensor,
                                         op=torch.distributed.ReduceOp.MIN)
            if status_tensor.item() == 0:
                logging.error("At least one process had a failure")
                if status:
                    raise ValueError(
                        f"Process with global rank {self._global_rank} entered"
                        " sync_all_processes with a passing status, but "
                        "another process indicated a failure")
Example #6
def convert_model_config_to_dict_config(
        cfg: Union['DictConfig', 'NemoConfig']) -> 'DictConfig':
    """
    Converts its input into a standard DictConfig.
    Possible input values are:
    -   DictConfig
    -   A dataclass which is a subclass of NemoConfig

    Args:
        cfg: A dict-like object.

    Returns:
        The equivalent DictConfig
    """
    if not _HAS_HYDRA:
        logging.error(
            "This function requires Hydra/Omegaconf and it was not installed.")
        exit(1)
    if not isinstance(cfg, (OmegaConf, DictConfig)) and is_dataclass(cfg):
        cfg = OmegaConf.structured(cfg)

    if not isinstance(cfg, DictConfig):
        raise ValueError(
            f"cfg constructor argument must be of type DictConfig/dict but got {type(cfg)} instead."
        )

    config = OmegaConf.to_container(cfg, resolve=True)
    config = OmegaConf.create(config)
    return config
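Usage sketch with a stand-in dataclass; in practice the input would be a NemoConfig subclass, and TinyConfig is hypothetical. It assumes the function above is importable.

from dataclasses import dataclass
from omegaconf import OmegaConf

@dataclass
class TinyConfig:
    lr: float = 1e-3
    name: str = 'demo'

cfg = convert_model_config_to_dict_config(TinyConfig())
print(OmegaConf.to_yaml(cfg))  # lr: 0.001 / name: demo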
Example #7
    def __setup_dataloader_from_config(self, cfg, shuffle_should_be: bool = True, name: str = "train"):
        if "dataset" not in cfg or not isinstance(cfg.dataset, DictConfig):
            raise ValueError(f"No dataset for {name}")
        if "dataloader_params" not in cfg or not isinstance(cfg.dataloader_params, DictConfig):
            raise ValueError(f"No dataloder_params for {name}")
        if shuffle_should_be:
            if 'shuffle' not in cfg.dataloader_params:
                logging.warning(
                    f"Shuffle should be set to True for {self}'s {name} dataloader but was not found in its "
                    "config. Manually setting to True"
                )
                with open_dict(cfg.dataloader_params):
                    cfg.dataloader_params.shuffle = True
            elif not cfg.dataloader_params.shuffle:
                logging.error(f"The {name} dataloader for {self} has shuffle set to False!!!")
        elif cfg.dataloader_params.shuffle:
            logging.error(f"The {name} dataloader for {self} has shuffle set to True!!!")

        if cfg.dataset._target_ == "nemo.collections.tts.torch.data.TTSDataset":
            phon_mode = contextlib.nullcontext()
            if hasattr(self.vocab, "set_phone_prob"):
                phon_mode = self.vocab.set_phone_prob(prob=None if name == "val" else self.vocab.phoneme_probability)

            with phon_mode:
                dataset = instantiate(
                    cfg.dataset,
                    text_normalizer=self.normalizer,
                    text_normalizer_call_kwargs=self.text_normalizer_call_kwargs,
                    text_tokenizer=self.vocab,
                )
        else:
            dataset = instantiate(cfg.dataset)

        return torch.utils.data.DataLoader(dataset, collate_fn=dataset.collate_fn, **cfg.dataloader_params)
Example #8
    def __setup_dataloader_from_config(self, cfg, shuffle_should_be: bool = True, name: str = "train"):
        if "dataset" not in cfg or not isinstance(cfg.dataset, DictConfig):
            raise ValueError(f"No dataset for {name}")
        if "dataloader_params" not in cfg or not isinstance(cfg.dataloader_params, DictConfig):
            raise ValueError(f"No dataloder_params for {name}")
        if shuffle_should_be:
            if 'shuffle' not in cfg.dataloader_params:
                logging.warning(
                    f"Shuffle should be set to True for {self}'s {name} dataloader but was not found in its "
                    "config. Manually setting to True"
                )
                with open_dict(cfg.dataloader_params):
                    cfg.dataloader_params.shuffle = True
            elif not cfg.dataloader_params.shuffle:
                logging.error(f"The {name} dataloader for {self} has shuffle set to False!!!")
        elif cfg.dataloader_params.shuffle:
            logging.error(f"The {name} dataloader for {self} has shuffle set to True!!!")

        # TODO(Oktai15): remove it in 1.8.0 version
        if cfg.dataset._target_ == "nemo.collections.asr.data.audio_to_text.FastPitchDataset":
            dataset = instantiate(cfg.dataset, parser=self.parser)
        elif cfg.dataset._target_ == "nemo.collections.tts.torch.data.TTSDataset":
            dataset = instantiate(
                cfg.dataset,
                text_normalizer=self.normalizer,
                text_normalizer_call_kwargs=self.text_normalizer_call_kwargs,
                text_tokenizer=self.vocab,
            )
        else:
            # TODO(Oktai15): remove it in 1.8.0 version
            dataset = instantiate(cfg.dataset)

        return torch.utils.data.DataLoader(dataset, collate_fn=dataset.collate_fn, **cfg.dataloader_params)
Example #9
    def __init__(
        self,
        train_tensors=[],
        wandb_name=None,
        wandb_project=None,
        args=None,
        update_freq=25,
    ):
        """
        Args:
            train_tensors: list of tensors to evaluate and log based on training batches
            wandb_name: wandb experiment name
            wandb_project: wandb project name
            args: argparse flags - will be logged as hyperparameters
            update_freq: frequency with which to log updates
        """
        super().__init__()

        if not _WANDB_AVAILABLE:
            logging.error(
                "Could not import wandb. Did you install it (pip install --upgrade wandb)?"
            )

        self._update_freq = update_freq
        self._train_tensors = train_tensors
        self._name = wandb_name
        self._project = wandb_project
        self._args = args
Example #10
    def restore_weights(self, restore_path: str):
        """Restores module/model's weights.
           For model parallel checkpoints the directory structure
           should be restore_path/mp_rank_0X/model_optim_rng.pt

        Args:
            restore_path (str): restore_path should be a file, or a directory when using model parallelism
        """
        self._restore_path = restore_path
        if os.path.isfile(restore_path):
            logging.info(f'restore_path: {restore_path} is a file. Assuming no megatron model parallelism')
            state_dict = torch.load(restore_path)
            # to load from Megatron pretrained checkpoint
            if 'model' in state_dict:
                self.language_model.load_state_dict(state_dict['model'][self._language_model_key])
            else:
                self.load_state_dict(state_dict)
            logging.info(f"weights restored from {restore_path}")
        elif os.path.isdir(restore_path):
            # need model parallel groups to restore model parallel checkpoints
            if model_parallel_is_initialized():
                model_parallel_rank = torch.distributed.get_rank(group=get_model_parallel_group())
                mp_restore_path = f'{restore_path}/mp_rank_{model_parallel_rank:02d}/model_optim_rng.pt'
                logging.info(f'Restoring model parallel checkpoint from: {mp_restore_path}')
                state_dict = torch.load(mp_restore_path)
                # to load from Megatron pretrained checkpoint
                if 'model' in state_dict:
                    self.language_model.load_state_dict(state_dict['model'][self._language_model_key])
                else:
                    self.load_state_dict(state_dict)
            else:
                logging.info('torch.distributed not initialized yet. Will not restore model parallel checkpoint')
        else:
            logging.error(f'restore_path: {restore_path} must be a file or directory.')
Example #11
    def restore_weights(self, restore_path: str):
        """Restores module/model's weights.
           For model parallel checkpoints the directory structure
           should be restore_path/mp_rank_0X/model_optim_rng.pt

        Args:
            restore_path (str): restore_path should be a file, or a directory when using model parallelism
        """
        self._restore_path = restore_path

        if os.path.isfile(restore_path):
            self._load_checkpoint(restore_path)
        elif os.path.isdir(restore_path):
            # need model parallel groups to restore model parallel checkpoints
            if model_parallel_is_initialized():
                model_parallel_rank = torch.distributed.get_rank(
                    group=get_model_parallel_group())
                mp_restore_path = f'{restore_path}/mp_rank_{model_parallel_rank:02d}/model_optim_rng.pt'
                self._load_checkpoint(mp_restore_path)
            else:
                logging.info(
                    'torch.distributed not initialized yet. Will not restore model parallel checkpoint'
                )
        else:
            logging.error(
                f'restore_path: {restore_path} must be a file or directory.')
Example #12
def configure_checkpointing(
    trainer: 'pytorch_lightning.Trainer', log_dir: Path, name: str, params: Dict,
):
    """ Adds ModelCheckpoint to trainer. Raises CheckpointMisconfigurationError if trainer already has a ModelCheckpoint
    callback or if trainer.weights_save_path was passed to Trainer.
    """
    for callback in trainer.callbacks:
        if isinstance(callback, ModelCheckpoint):
            raise CheckpointMisconfigurationError(
                "The pytorch lightning trainer that was passed to exp_manager contained a ModelCheckpoint "
                "and create_checkpoint_callback was set to True. Please either set create_checkpoint_callback "
                "to False, or remove ModelCheckpoint from the lightning trainer"
            )
    if Path(trainer.weights_save_path) != Path.cwd():
        raise CheckpointMisconfigurationError(
            "The pytorch lightning was passed weights_save_path. This variable is ignored by exp_manager"
        )

    # Create the callback and attach it to trainer
    if params.filepath is None:
        params.filepath = Path(log_dir / 'checkpoints' / f'--{{{params.monitor}:.2f}}-{{epoch}}')
    if params.prefix is None:
        params.prefix = name

    if "val" in params.monitor and trainer.max_epochs != -1 and trainer.max_epochs < trainer.check_val_every_n_epoch:
        logging.error(
            "The checkpoint callback was told to monitor a validation value but trainer.max_epochs("
            f"{trainer.max_epochs}) was less than trainer.check_val_every_n_epoch({trainer.check_val_every_n_epoch})."
            f"It is very likely this run will fail with ModelCheckpoint(monitor='{params.monitor}') not found in the "
            "returned metrics. Please ensure that validation is run within trainer.max_epochs."
        )

    checkpoint_callback = NeMoModelCheckpoint(**params)
    trainer.callbacks.append(checkpoint_callback)
Example #13
    def __setup_dataloader_from_config(self,
                                       cfg,
                                       shuffle_should_be: bool = True,
                                       name: str = "train"):
        if "dataset" not in cfg or not isinstance(cfg.dataset, DictConfig):
            raise ValueError(f"No dataset for {name}")
        if "dataloader_params" not in cfg or not isinstance(
                cfg.dataloader_params, DictConfig):
            raise ValueError(f"No dataloder_params for {name}")
        if shuffle_should_be:
            if 'shuffle' not in cfg.dataloader_params:
                logging.warning(
                    f"Shuffle should be set to True for {self}'s {name} dataloader but was not found in its "
                    "config. Manually setting to True")
                with open_dict(cfg.dataloader_params):
                    cfg.dataloader_params.shuffle = True
            elif not cfg.dataloader_params.shuffle:
                logging.error(
                    f"The {name} dataloader for {self} has shuffle set to False!!!"
                )
        elif cfg.dataloader_params.shuffle:
            logging.error(
                f"The {name} dataloader for {self} has shuffle set to True!!!")

        kwargs_dict = {}
        if cfg.dataset._target_ == "nemo.collections.asr.data.audio_to_text.FastPitchDataset":
            kwargs_dict["parser"] = self.parser
        dataset = instantiate(cfg.dataset, **kwargs_dict)
        return torch.utils.data.DataLoader(dataset,
                                           collate_fn=dataset.collate_fn,
                                           **cfg.dataloader_params)
Example #14
def _convert_config(cfg: 'OmegaConf'):
    """ Recursive function convertint the configuration from old hydra format to the new one. """
    if not _HAS_HYDRA:
        logging.error(
            "This function requires Hydra/Omegaconf and it was not installed.")
        exit(1)

    # Get rid of cls -> _target_.
    if 'cls' in cfg and '_target_' not in cfg:
        cfg._target_ = cfg.pop('cls')

    # Get rid of params.
    if 'params' in cfg:
        params = cfg.pop('params')
        for param_key, param_val in params.items():
            cfg[param_key] = param_val

    # Recursion.
    try:
        for _, sub_cfg in cfg.items():
            if isinstance(sub_cfg, DictConfig):
                _convert_config(sub_cfg)
    except omegaconf_errors.OmegaConfBaseException as e:
        logging.warning(
            f"Skipped conversion for config/subconfig:\n{cfg}\n Reason: {e}.")
Example #15
def error_checks(trainer: 'pytorch_lightning.Trainer',
                 cfg: Optional[Union[DictConfig, Dict]] = None):
    """
    Checks that the passed trainer is compliant with NeMo and exp_manager's passed configuration. Checks that it:
        - Raises an error when hydra has changed the working directory, since this causes issues with lightning's DDP
        - Raises an error when the trainer has a logger defined but create_tensorboard_logger or create_wandb_logger is True
        - Logs error messages when 1) run on multi-node without SLURM handling the processes, and 2) run on multi-gpu without DDP
    """
    if HydraConfig.initialized() and get_original_cwd() != os.getcwd():
        raise ValueError(
            "Hydra changed the working directory. This interferes with ExpManger's functionality. Please pass "
            "hydra.run.dir=. to your python script.")
    if trainer.logger is not None and (cfg.create_tensorboard_logger
                                       or cfg.create_wandb_logger):
        raise LoggerMisconfigurationError(
            "The pytorch lightning trainer that was passed to exp_manager contained a logger, and either "
            f"create_tensorboard_logger: {cfg.create_tensorboard_logger} or create_wandb_logger: "
            f"{cfg.create_wandb_logger} was set to True. These can only be used if trainer does not already have a"
            " logger.")
    if trainer.num_nodes > 1 and not check_slurm(trainer):
        logging.error(
            "You are running multi-node training without SLURM handling the processes."
            " Please note that this is not tested in NeMo and could result in errors."
        )
    if trainer.num_gpus > 1 and not isinstance(
            trainer.accelerator.training_type_plugin, DDPPlugin):
        logging.error(
            "You are running multi-gpu without ddp.Please note that this is not tested in NeMo and could result in "
            "errors.")
Example #16
def get_translation():
    try:
        time_s = time.time()
        langpair = request.args["langpair"]
        src = request.args["text"]
        do_moses = request.args.get('do_moses', False)
        if langpair in MODELS_DICT:
            if do_moses:
                result = MODELS_DICT[langpair].translate(
                    [src], source_lang=langpair.split('-')[0], target_lang=langpair.split('-')[1]
                )
            else:
                result = MODELS_DICT[langpair].translate([src])

            duration = time.time() - time_s
            logging.info(
                f"Translated in {duration}. Input was: {request.args['text']} <############> Translation was: {result[0]}"
            )
            res = {'translation': result[0]}
            response = flask.jsonify(res)
            response.headers.add('Access-Control-Allow-Origin', '*')
            return response

        else:
            logging.error(f"Got the following langpair: {langpair} which was not found")
    except Exception as ex:
        res = {'translation': str(ex)}
        response = flask.jsonify(res)
        response.headers.add('Access-Control-Allow-Origin', '*')
        return response
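A sketch of how such an endpoint might be queried from a client; the route path and port are assumptions, since the snippet only shows the handler body.

import requests

resp = requests.get(
    'http://localhost:5000/translate',  # hypothetical route/port
    params={'langpair': 'en-de', 'text': 'Hello world', 'do_moses': '1'},
)
print(resp.json()['translation'])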
Example #17
def check_explicit_log_dir(trainer: 'pytorch_lightning.Trainer',
                           explicit_log_dir: Union[Path, str], exp_dir: str,
                           name: str, version: str) -> Tuple[Path, str, str, str]:
    """ Checks that the passed arguments are compatible with explicit_log_dir.

    Returns:
        log_dir (Path): the log_dir
        exp_dir (str): the base exp_dir without name nor version
        name (str): The name of the experiment
        version (str): The version of the experiment

    Raises:
        LoggerMisconfigurationError
    """
    if trainer.logger is not None:
        raise LoggerMisconfigurationError(
            "The pytorch lightning trainer that was passed to exp_manager contained a logger and explicit_log_dir: "
            f"{explicit_log_dir} was pass to exp_manager. Please remove the logger from the lightning trainer."
        )
    # Checking only (explicit_log_dir) vs (exp_dir and version).
    # The `name` will be used as the actual name of checkpoint/archive.
    if exp_dir or version:
        logging.error(
            f"exp_manager received explicit_log_dir: {explicit_log_dir} and at least one of exp_dir: {exp_dir}, "
            f"or version: {version}. Please note that exp_dir, name, and version will be ignored."
        )
    if is_global_rank_zero() and Path(explicit_log_dir).exists():
        logging.warning(
            f"Exp_manager is logging to {explicit_log_dir}, but it already exists."
        )
    return Path(explicit_log_dir), str(explicit_log_dir), "", ""
Example #18
    def _try_jit_compile_model(self, module, try_script):
        jitted_model = None
        if try_script:
            try:
                jitted_model = torch.jit.script(module)
            except Exception as e:
                logging.error(f"jit.script() failed!\n{e}")
        return jitted_model
Example #19
    @classmethod
    def from_file(
        cls, audio_file, target_sr=None, int_values=False, offset=0, duration=0, trim=False, orig_sr=None,
    ):
        """
        Load a file supported by librosa and return as an AudioSegment.
        :param audio_file: path of file to load
        :param target_sr: the desired sample rate
        :param int_values: if true, load samples as 32-bit integers
        :param offset: offset in seconds when loading audio
        :param duration: duration in seconds when loading audio
        :return: numpy array of samples
        """
        samples = None
        if not isinstance(audio_file, str) or os.path.splitext(audio_file)[-1] in sf_supported_formats:
            try:
                with sf.SoundFile(audio_file, 'r') as f:
                    dtype = 'int32' if int_values else 'float32'
                    sample_rate = f.samplerate
                    if offset > 0:
                        f.seek(int(offset * sample_rate))
                    if duration > 0:
                        samples = f.read(int(duration * sample_rate), dtype=dtype)
                    else:
                        samples = f.read(dtype=dtype)
                samples = samples.transpose()
            except RuntimeError as e:
                logging.error(
                    f"Loading {audio_file} via SoundFile raised RuntimeError: `{e}`. "
                    f"NeMo will fallback to loading via pydub."
                )
        elif isinstance(audio_file, str) and audio_file.strip()[-1] == "|":
            f = open_like_kaldi(audio_file, "rb")
            sample_rate, samples = read_kaldi(f)
            if offset > 0:
                samples = samples[int(offset * sample_rate) :]
            if duration > 0:
                samples = samples[: int(duration * sample_rate)]
            if not int_values:
                abs_max_value = np.abs(samples).max()
                samples = np.array(samples, dtype=np.float64) / abs_max_value  # np.float alias was removed in NumPy 1.24

        if samples is None:
            try:
                samples = Audio.from_file(audio_file)
                sample_rate = samples.frame_rate
                if offset > 0:
                    # pydub does things in milliseconds
                    seconds = offset * 1000
                    samples = samples[int(seconds) :]
                if duration > 0:
                    seconds = duration * 1000
                    samples = samples[: int(seconds)]
                samples = np.array(samples.get_array_of_samples())
            except CouldntDecodeError as e:
                logging.error(f"Loading {audio_file} via pydub raised CouldntDecodeError: `{e}`.")

        return cls(samples, sample_rate, target_sr=target_sr, trim=trim, orig_sr=orig_sr)
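The pydub fallback branch can be illustrated standalone; note that pydub slices in milliseconds and get_array_of_samples() returns raw integer samples.

import numpy as np
from pydub import AudioSegment as Audio

seg = Audio.from_file('clip.mp3')                    # any ffmpeg-decodable format
seg = seg[int(0.5 * 1000):int((0.5 + 2.0) * 1000)]   # 0.5 s offset, then 2 s duration
samples = np.array(seg.get_array_of_samples())
print(seg.frame_rate, samples.shape)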
Example #20
File: g2ps.py Project: ggrunin/NeMo
    @staticmethod
    def _parse_as_cmu_dict(phoneme_dict_path=None, encoding='latin-1'):
        if phoneme_dict_path is None:
            # this part of code downloads file, but it is not rank zero guarded
            # Try to check if torch distributed is available, if not get global rank zero to download corpora and make
            # all other ranks sleep for a minute
            if torch.distributed.is_available(
            ) and torch.distributed.is_initialized():
                group = torch.distributed.group.WORLD
                if is_global_rank_zero():
                    try:
                        nltk.data.find('corpora/cmudict.zip')
                    except LookupError:
                        nltk.download('cmudict', quiet=True)
                torch.distributed.barrier(group=group)
            elif is_global_rank_zero():
                logging.error(
                    "Torch distributed needs to be initialized before you initialize EnglishG2p. This class is prone to "
                    "data access race conditions. Now downloading corpora from global rank 0. If other ranks pass this "
                    "before rank 0, errors might result.")
                try:
                    nltk.data.find('corpora/cmudict.zip')
                except LookupError:
                    nltk.download('cmudict', quiet=True)
            else:
                logging.error(
                    "Torch distributed needs to be initialized before you initialize EnglishG2p. This class is prone to "
                    "data access race conditions. This process is not rank 0, and now going to sleep for 1 min. If this "
                    "rank wakes from sleep prior to rank 0 finishing downloading, errors might result."
                )
                time.sleep(60)

            logging.warning(
                "English g2p_dict will be used from nltk.corpus.cmudict.dict(), because phoneme_dict_path=None. "
                "Note that nltk.corpus.cmudict.dict() has an old version (0.6) of CMUDict. "
                "You can use the latest official version of CMUDict (0.7b) with additional changes from NVIDIA directly from NeMo "
                "using the path scripts/tts_dataset_files/cmudict-0.7b_nv22.01."
            )

            return nltk.corpus.cmudict.dict()

        _alt_re = re.compile(r'\([0-9]+\)')
        g2p_dict = {}
        with open(phoneme_dict_path, encoding=encoding) as file:
            for line in file:
                if len(line) and ('A' <= line[0] <= 'Z' or line[0] == "'"):
                    parts = line.split('  ')
                    word = re.sub(_alt_re, '', parts[0])
                    word = word.lower()

                    pronunciation = parts[1].strip().split(" ")
                    if word in g2p_dict:
                        g2p_dict[word].append(pronunciation)
                    else:
                        g2p_dict[word] = [pronunciation]
        return g2p_dict
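A worked example of the parsing rules above on two CMUdict-style lines; alternate pronunciations are marked with "(1)", "(2)", ... and collapse onto the same lowercased headword.

import re

_alt_re = re.compile(r'\([0-9]+\)')
g2p_dict = {}
for line in ['READ  R IY1 D\n', 'READ(1)  R EH1 D\n']:
    if len(line) and ('A' <= line[0] <= 'Z' or line[0] == "'"):
        parts = line.split('  ')                      # two spaces separate word and phones
        word = re.sub(_alt_re, '', parts[0]).lower()  # 'READ(1)' -> 'read'
        g2p_dict.setdefault(word, []).append(parts[1].strip().split(' '))
print(g2p_dict)  # {'read': [['R', 'IY1', 'D'], ['R', 'EH1', 'D']]}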
Example #21
    def __restore_from(self, path, state):
        if not os.path.isdir(path):
            if self._force_load:
                raise ValueError("force_load was set to True for checkpoint callback but a checkpoint was not found.")
            logging.warning(f"Checkpoint folder {path} not found!")
        else:
            logging.info(f"Found checkpoint folder {path}. Will attempt to restore checkpoints from it.")
            modules_to_restore = []
            modules_to_restore_name = []
            for module in AppState().modules:
                if module.num_weights > 0:
                    modules_to_restore.append(module)
                    modules_to_restore_name.append(str(module))
            step_check = None
            try:
                module_checkpoints, steps = get_checkpoint_from_dir(modules_to_restore_name, path, return_steps=True)

                # If the steps are different, print a warning message
                for step in steps:
                    if step_check is None:
                        step_check = step
                    elif step != step_check:
                        logging.warning("Restoring from modules checkpoints where the training step does not match")
                        break

                for mod, checkpoint in zip(modules_to_restore, module_checkpoints):
                    mod.restore_from(checkpoint, state["local_rank"])
            except ValueError as e:
                if self._force_load:
                    raise ValueError(
                        "force_load was set to True for checkpoint callback but a checkpoint was not found."
                    )
                logging.warning(e)
                logging.warning(
                    f"Checkpoint folder {path} was present but nothing was restored. Continuing training from random "
                    "initialization."
                )
                return

            try:
                trainer_checkpoints, steps = get_checkpoint_from_dir(["trainer"], path, return_steps=True)
                if step_check is not None and step_check != steps[0]:
                    logging.error(
                        "The step we are restoring from the trainer checkpoint does not match one or more steps that "
                        "are being restored from modules."
                    )
                state.restore_state_from(trainer_checkpoints[0])
            except ValueError as e:
                logging.warning(e)
                logging.warning(
                    "Trainer state such as optimizer state and current step/epoch was not restored. Pretrained weights"
                    " have still been restore and fine-tuning should continue fine."
                )
                return
Example #22
def compile_helper():
    """Compile helper function ar runtime. Make sure this
    is invoked on a single process."""

    path = os.path.abspath(os.path.dirname(__file__))
    ret = subprocess.run(['make', '-C', path])
    if ret.returncode != 0:
        logging.error("Making C++ dataset helpers module failed, exiting.")
        import sys

        sys.exit(1)
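A sketch of the single-process guard the docstring asks for, assuming torch.distributed may or may not be initialized:

import torch.distributed as dist

if not dist.is_initialized() or dist.get_rank() == 0:
    compile_helper()        # only rank 0 (or a lone process) runs make
if dist.is_initialized():
    dist.barrier()          # other ranks wait until the helpers are built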
Example #23
    def on_action_start(self):
        if self.global_rank is None or self.global_rank == 0:
            if self._wandb_name is not None or self._wandb_project is not None:
                if _WANDB_AVAILABLE and wandb.run is None:
                    wandb.init(name=self._wandb_name, project=self._wandb_project)
                elif _WANDB_AVAILABLE and wandb.run is not None:
                    logging.info("Re-using wandb session")
                else:
                    logging.error("Could not import wandb. Did you install it (pip install --upgrade wandb)?")
                    logging.info("Will not log data to weights and biases.")
                    self._wandb_name = None
                    self._wandb_project = None
Example #24
    def on_train_start(self, state):
        if state["global_rank"] is None or state["global_rank"] == 0:
            if _WANDB_AVAILABLE and wandb.run is None:
                wandb.init(name=self._name, project=self._project)
                if self._args is not None:
                    wandb.config.update(self._args)
            elif _WANDB_AVAILABLE and wandb.run is not None:
                logging.info("Re-using wandb session")
            else:
                logging.error("Could not import wandb. Did you install it (pip install --upgrade wandb)?")
                logging.info("Will not log data to weights and biases.")
                self._step_freq = -1
Example #25
    def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
        if isinstance(cfg, dict):
            cfg = OmegaConf.create(cfg)
        super().__init__(cfg=cfg, trainer=trainer)

        self.audio_to_melspec_precessor = instantiate(cfg.preprocessor)
        self.encoder = instantiate(cfg.encoder)
        self.variance_adapter = instantiate(cfg.variance_adaptor)

        self.generator = instantiate(cfg.generator)
        self.multiperioddisc = MultiPeriodDiscriminator()
        self.multiscaledisc = MultiScaleDiscriminator()

        self.melspec_fn = instantiate(cfg.preprocessor,
                                      highfreq=None,
                                      use_grads=True)
        self.mel_val_loss = L1MelLoss()
        self.durationloss = DurationLoss()
        self.feat_matching_loss = FeatureMatchingLoss()
        self.disc_loss = DiscriminatorLoss()
        self.gen_loss = GeneratorLoss()
        self.mseloss = torch.nn.MSELoss()

        self.energy = cfg.add_energy_predictor
        self.pitch = cfg.add_pitch_predictor
        self.mel_loss_coeff = cfg.mel_loss_coeff
        self.pitch_loss_coeff = cfg.pitch_loss_coeff
        self.energy_loss_coeff = cfg.energy_loss_coeff
        self.splice_length = cfg.splice_length

        self.use_energy_pred = False
        self.use_pitch_pred = False
        self.log_train_images = False
        self.logged_real_samples = False
        self._tb_logger = None
        self.sample_rate = cfg.sample_rate
        self.hop_size = cfg.hop_size

        # Parser and mappings are used for inference only.
        self.parser = parsers.make_parser(name='en')
        if 'mappings_filepath' in cfg:
            mappings_filepath = cfg.get('mappings_filepath')
        else:
            logging.error(
                "ERROR: You must specify a mappings.json file in the config file under model.mappings_filepath."
            )
            # Without a path, the register_artifact call below would hit an undefined name.
            raise ValueError("model.mappings_filepath was not found in the config")
        mappings_filepath = self.register_artifact('mappings_filepath',
                                                   mappings_filepath)
        with open(mappings_filepath, 'r') as f:
            mappings = json.load(f)
            self.word2phones = mappings['word2phones']
            self.phone2idx = mappings['phone2idx']
Example #26
def _add_subconfig_keys(model_cfg: 'DictConfig', update_cfg: 'DictConfig',
                        subconfig_key: str):
    """
    For certain sub-configs, the default values specified by the NemoConfig class are insufficient.
    Supporting every potential value in the merge with `update_cfg` would require
    explicit definition of all possible cases.

    An example of such a case is Optimizers, and their equivalent Schedulers. All optimizers share a few basic
    details - such as name and lr, but almost all require additional parameters - such as weight decay.
    It is impractical to create a config for every single optimizer + every single scheduler combination.

    In such a case, we perform a dual merge. The Optim and Sched Dataclass contain the bare minimum essential
    components. The extra values are provided via update_cfg.

    In order to enable the merge, we first need to update the update sub-config to incorporate the keys,
    with dummy temporary values (merge update config with model config). This is done on a copy of the
    update sub-config, as the actual override values might be overridden by the NemoConfig defaults.

    Then we perform a merge of this temporary sub-config with the actual override config in a later step
    (merge model_cfg with original update_cfg, done outside this function).

    Args:
        model_cfg: A DictConfig instantiated from the NemoConfig subclass.
        update_cfg: A DictConfig that mirrors the structure of `model_cfg`, used to update its default values.
        subconfig_key: A str key used to check and update the sub-config.

    Returns:
        A ModelPT DictConfig with additional keys added to the sub-config.
    """
    if not _HAS_HYDRA:
        logging.error(
            "This function requires Hydra/Omegaconf and it was not installed.")
        exit(1)
    with open_dict(model_cfg.model):
        # Create copy of original model sub config
        if subconfig_key in update_cfg.model:
            if subconfig_key not in model_cfg.model:
                # create the key as a placeholder
                model_cfg.model[subconfig_key] = None

            subconfig = copy.deepcopy(model_cfg.model[subconfig_key])
            update_subconfig = copy.deepcopy(update_cfg.model[subconfig_key])

            # Add the keys and update temporary values, will be updated during full merge
            subconfig = OmegaConf.merge(update_subconfig, subconfig)
            # Update sub config
            model_cfg.model[subconfig_key] = subconfig

    return model_cfg
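A toy illustration of the key-priming merge. OmegaConf.merge gives later arguments priority, so the model defaults temporarily win, exactly as the docstring warns:

from omegaconf import OmegaConf

model_optim = OmegaConf.create({'name': 'adam', 'lr': 1e-3})
update_optim = OmegaConf.create({'lr': 1e-4, 'weight_decay': 0.01})
primed = OmegaConf.merge(update_optim, model_optim)
print(OmegaConf.to_yaml(primed))
# lr: 0.001           <- dummy temporary value, fixed by the later full merge
# weight_decay: 0.01  <- the extra key now exists, so the full merge cannot fail
# name: adam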
Example #27
    def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
        if isinstance(cfg, dict):
            cfg = OmegaConf.create(cfg)
        super().__init__(cfg=cfg, trainer=trainer)

        schema = OmegaConf.structured(FastSpeech2Config)
        # ModelPT ensures that cfg is a DictConfig, but do this second check in case ModelPT changes
        if isinstance(cfg, dict):
            cfg = OmegaConf.create(cfg)
        elif not isinstance(cfg, DictConfig):
            raise ValueError(
                f"cfg was type: {type(cfg)}. Expected either a dict or a DictConfig"
            )
        # Ensure passed cfg is compliant with schema
        OmegaConf.merge(cfg, schema)

        self.pitch = cfg.add_pitch_predictor
        self.energy = cfg.add_energy_predictor
        self.duration_coeff = cfg.duration_coeff

        self.audio_to_melspec_preprocessor = instantiate(
            self._cfg.preprocessor)
        self.encoder = instantiate(self._cfg.encoder)
        self.mel_decoder = instantiate(self._cfg.decoder)
        self.variance_adapter = instantiate(self._cfg.variance_adaptor)
        self.loss = L2MelLoss()
        self.mseloss = torch.nn.MSELoss()
        self.durationloss = DurationLoss()

        self.log_train_images = False

        # Parser and mappings are used for inference only.
        self.parser = parsers.make_parser(name='en')
        if 'mappings_filepath' in cfg:
            mappings_filepath = cfg.get('mappings_filepath')
        else:
            logging.error(
                "ERROR: You must specify a mappings.json file in the config file under model.mappings_filepath."
            )
            # Without a path, the register_artifact call below would hit an undefined name.
            raise ValueError("model.mappings_filepath was not found in the config")
        mappings_filepath = self.register_artifact('mappings_filepath',
                                                   mappings_filepath)
        with open(mappings_filepath, 'r') as f:
            mappings = json.load(f)
            self.word2phones = mappings['word2phones']
            self.phone2idx = mappings['phone2idx']
Example #28
    def __init__(
        self,
        step_freq: int = 100,
        tensors_to_log: List[Union[str, NmTensor]] = ["loss"],
        wandb_name: str = None,
        wandb_project: str = None,
        args=None,
        log_epoch: bool = True,
        log_lr: bool = True,
    ):
        if not _WANDB_AVAILABLE:
            logging.error("Could not import wandb. Did you install it (pip install --upgrade wandb)?")
        self._step_freq = step_freq
        self._tensors_to_log = tensors_to_log
        self._name = wandb_name
        self._project = wandb_project
        self._args = args
        self._last_epoch_start = None
        self._log_epoch = log_epoch
        self._log_lr = log_lr
Example #29
    def forward(self, system_acts):
        """
        Generates system response 
        Args:
            system_acts (list): system actions in the format: [['Inform', 'Train', 'Day', 'wednesday'], []] [act, domain, slot, slot_value]
        Returns:
            system_uttr (str): generated system utterance 
        """
        action = {}
        for intent, domain, slot, value in system_acts:
            k = '-'.join([domain, intent])
            action.setdefault(k, [])
            action[k].append([slot, value])
        dialog_acts = action
        mode = self.mode
        try:
            if mode == 'manual':
                system_uttr = self._manual_generate(
                    dialog_acts, self.manual_system_template)

            elif mode == 'auto':
                system_uttr = self._auto_generate(dialog_acts,
                                                  self.auto_system_template)

            elif mode == 'auto_manual':
                template1 = self.auto_system_template
                template2 = self.manual_system_template
                system_uttr = self._auto_generate(dialog_acts, template1)

                if system_uttr == 'None':
                    system_uttr = self._manual_generate(dialog_acts, template2)
            else:
                raise Exception(
                    "Invalid mode! Available modes: auto, manual, auto_manual")
            # truncate a system utterance with multiple questions
            system_uttr = self.truncate_sys_response(system_uttr)
            logging.info("NLG output = System reply: %s", system_uttr)
            return system_uttr
        except Exception as e:
            logging.error('Error in processing: %s', dialog_acts)
            raise e
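The act-grouping step at the top of forward() can be traced with a concrete input:

system_acts = [['Inform', 'Train', 'Day', 'wednesday'],
               ['Inform', 'Train', 'Leave', '10:15']]
action = {}
for intent, domain, slot, value in system_acts:
    key = '-'.join([domain, intent])          # 'Train-Inform'
    action.setdefault(key, []).append([slot, value])
print(action)  # {'Train-Inform': [['Day', 'wednesday'], ['Leave', '10:15']]}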
Example #30
    def on_action_start(self, state):
        if state["global_rank"] is None or state["global_rank"] == 0:
            if _WANDB_AVAILABLE and wandb.run is None:
                wandb.init(job_type='train',
                           id=self._runid,
                           tags=['train', 'nemo'],
                           group='train',
                           name=self._name,
                           project='asr',
                           entity='cprc')
                if self._args is not None:
                    logging.info('init wandb session and append args')
                    wandb.config.update(self._args)
            elif _WANDB_AVAILABLE and wandb.run is not None:
                logging.info("Re-using wandb session")
            else:
                logging.error(
                    "Could not import wandb. Did you install it (pip install --upgrade wandb)?"
                )
                logging.info("Will not log data to weights and biases.")
                self._update_freq = -1