Example #1
def get_registered_adapter(cls: Union[str, type]) -> Optional[AdapterRegistryInfo]:
    """
    Resolves the provided `cls` (a string path to a class, a registered base class, or an adapter
    class) to obtain the metadata for the adapter.

    Args:
        cls: A str (absolute path to a class), a base class, or an adapter class (which must
            already be registered).

    Returns:
        An AdapterRegistryInfo object if resolution succeeds, otherwise None.
    """
    global ADAPTER_REGISTRY
    if isinstance(cls, str):
        cls = model_utils.import_class_by_path(cls)

    # If an adapter class was provided, de-reference its base class
    if hasattr(cls, '_meta_base_class'):
        cls = cls._meta_base_class

    class_path = f'{cls.__module__}.{cls.__name__}'

    # If base class, check registry
    if class_path in ADAPTER_REGISTRY:
        return ADAPTER_REGISTRY[class_path]

    return None
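
A note on the shared helper: every example on this page resolves a dotted class path with import_class_by_path. As a point of reference, a minimal sketch of such a resolver (an assumption for illustration, not necessarily NeMo's exact implementation) looks like this:

import importlib

def import_class_by_path(path: str) -> type:
    # Split 'pkg.module.ClassName' into a module path and a class name,
    # import the module, then fetch the class attribute from it.
    module_path, _, class_name = path.rpartition('.')
    module = importlib.import_module(module_path)
    return getattr(module, class_name)
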
Example #2
def _get_class_from_path(domain, subdomains, imp):
    path = _build_import_path(domain, subdomains, imp)

    class_ = None
    result = None

    try:
        class_ = model_utils.import_class_by_path(path)

        if inspect.isclass(class_):
            # Is the class wrapped by a wrapt decorator at the class level? Unwrap it for the checks below.
            if isinstance(class_, wrapt.FunctionWrapper):
                class_ = class_.__wrapped__

            # Subclass tests
            if issubclass(class_, (Model, torch.nn.Module)):
                result = class_
        else:
            class_ = None

        error = None

    except Exception:
        error = traceback.format_exc()

    return class_, result, error
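
The wrapt branch above is worth a short illustration: a class decorated via wrapt.decorator is replaced by a wrapt.FunctionWrapper proxy, and the original class is recovered through __wrapped__. A minimal sketch (the decorator name traced is hypothetical):

import inspect
import wrapt

@wrapt.decorator
def traced(wrapped, instance, args, kwargs):
    # Pass-through decorator: simply call the wrapped class.
    return wrapped(*args, **kwargs)

@traced
class MyModel:
    pass

print(type(MyModel).__name__)                # FunctionWrapper
print(inspect.isclass(MyModel.__wrapped__))  # True
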
Example #3
def initialize_model(model_name):
    # load model
    if model_name not in MODEL_CACHE:
        if '.nemo' in model_name:
            # use local model
            model_name_no_ext = os.path.splitext(model_name)[0]
            model_path = os.path.join('models', model_name_no_ext, model_name)

            # Extract config
            model_cfg = nemo_asr.models.ASRModel.restore_from(
                restore_path=model_path, return_config=True)
            classpath = model_cfg.target  # original class path
            imported_class = model_utils.import_class_by_path(
                classpath)  # type: ASRModel
            logging.info(f"Restoring local model : {imported_class.__name__}")

            # load model from checkpoint
            model = imported_class.restore_from(
                restore_path=model_path, map_location='cpu')  # type: ASRModel

        else:
            # use pretrained model
            model = nemo_asr.models.ASRModel.from_pretrained(
                model_name, map_location='cpu')

        model.freeze()

        # cache model
        MODEL_CACHE[model_name] = model

    model = MODEL_CACHE[model_name]
    return model
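
A hypothetical call sequence for the function above, assuming a global MODEL_CACHE dict and a local file laid out as models/<name-without-ext>/<name>.nemo:

MODEL_CACHE = {}

# First call loads and caches the model; a repeated call is a cache hit.
model = initialize_model('my_asr_model.nemo')            # local: models/my_asr_model/my_asr_model.nemo (hypothetical)
model = initialize_model('stt_en_conformer_ctc_small')   # pretrained model fetched by name
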
Example #4
    def from_config_dict(cls, config: 'DictConfig'):
        """Instantiates object using DictConfig-based configuration"""
        # Resolve the config dict
        if _HAS_HYDRA:
            if isinstance(config, DictConfig):
                config = OmegaConf.to_container(config, resolve=True)
                config = OmegaConf.create(config)
                OmegaConf.set_struct(config, True)

            config = maybe_update_config_version(config)

        # Hydra 0.x API
        if ('cls' in config
                or 'target' in config) and 'params' in config and _HAS_HYDRA:
            # regular hydra-based instantiation
            instance = hydra.utils.instantiate(config=config)
        # Hydra 1.x API
        elif '_target_' in config and _HAS_HYDRA:
            # regular hydra-based instantiation
            instance = hydra.utils.instantiate(config=config)
        else:
            instance = None
            imported_cls_tb = None
            # Attempt class path resolution from config `target` class (if it exists)
            if 'target' in config:
                target_cls = config["target"]  # No guarantee that this is an omegaconf class
                imported_cls = None
                try:
                    # try to import the target class
                    imported_cls = import_class_by_path(target_cls)
                except Exception:
                    imported_cls_tb = traceback.format_exc()

                # try instantiating model with target class
                if imported_cls is not None:
                    # if calling class (cls) is subclass of imported class,
                    # use subclass instead
                    if issubclass(cls, imported_cls):
                        imported_cls = cls

                    try:
                        instance = imported_cls(cfg=config)
                    except Exception:
                        imported_cls_tb = traceback.format_exc()
                        instance = None

            # target class resolution was unsuccessful, fall back to current `cls`
            if instance is None:
                if imported_cls_tb is not None:
                    logging.debug(
                        f"Model instantiation from target class {target_cls} failed with following error.\n"
                        f"Falling back to `cls`.\n"
                        f"{imported_cls_tb}")
                instance = cls(cfg=config)

        if not hasattr(instance, '_cfg'):
            instance._cfg = config
        return instance
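
The three branches accept differently shaped configs. A sketch of each shape (the class path my_pkg.MyModel is hypothetical):

from omegaconf import OmegaConf

# Hydra 0.x style: 'cls' or 'target' plus 'params' -> hydra.utils.instantiate
cfg_hydra0 = OmegaConf.create({'cls': 'my_pkg.MyModel', 'params': {'hidden': 128}})

# Hydra 1.x style: '_target_' -> hydra.utils.instantiate
cfg_hydra1 = OmegaConf.create({'_target_': 'my_pkg.MyModel', 'hidden': 128})

# Plain 'target' without 'params': manual import_class_by_path + imported_cls(cfg=config)
cfg_manual = OmegaConf.create({'target': 'my_pkg.MyModel', 'hidden': 128})
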
Example #5
    def from_config_dict(cls, config: DictConfig):
        """Instantiates object using DictConfig-based configuration"""
        # Resolve the config dict
        if isinstance(config, DictConfig):
            config = OmegaConf.to_container(config, resolve=True)
            config = OmegaConf.create(config)
            OmegaConf.set_struct(config, True)

        config = maybe_update_config_version(config)

        # Hydra 0.x API
        if ('cls' in config or 'target' in config) and 'params' in config:
            # regular hydra-based instantiation
            instance = hydra.utils.instantiate(config=config)
        # Hydra 1.x API
        elif '_target_' in config:
            # regular hydra-based instantiation
            instance = hydra.utils.instantiate(config=config)
        else:
            instance = None

            # Attempt class path resolution from config `target` class (if it exists)
            if 'target' in config:
                target_cls = config.target
                imported_cls = None
                try:
                    # try to import the target class
                    imported_cls = import_class_by_path(target_cls)
                except (ImportError, ModuleNotFoundError):
                    logging.debug(f'Target class `{target_cls}` could not be imported, falling back to original cls')

                # try instantiating model with target class
                if imported_cls is not None:
                    try:
                        instance = imported_cls(cfg=config)
                    except Exception:
                        imported_cls_tb = traceback.format_exc()
                        logging.debug(
                            f"Model instantiation from target class failed with following error.\n"
                            f"Falling back to `cls`.\n"
                            f"{imported_cls_tb}"
                        )
                        instance = None

            # target class resolution was unsuccessful, fall back to current `cls`
            if instance is None:
                instance = cls(cfg=config)

        if not hasattr(instance, '_cfg'):
            instance._cfg = config
        return instance
Example #6
def main(cfg: TranscriptionConfig):
    logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')

    if cfg.model_path is None and cfg.pretrained_name is None:
        raise ValueError("Both cfg.model_path and cfg.pretrained_name cannot be None !")

    # setup gpu
    if cfg.cuda is None:
        cfg.cuda = torch.cuda.is_available()

    if type(cfg.cuda) == int:
        device_id = int(cfg.cuda)
    else:
        device_id = 0

    device = torch.device(f'cuda:{device_id}' if cfg.cuda else 'cpu')

    # setup model
    if cfg.model_path is not None:
        # restore model from .nemo file path
        model_cfg = ASRModel.restore_from(restore_path=cfg.model_path, return_config=True)
        classpath = model_cfg.target  # original class path
        imported_class = model_utils.import_class_by_path(classpath)  # type: ASRModel
        logging.info(f"Restoring model : {imported_class.__name__}")

        asr_model = imported_class.restore_from(restore_path=cfg.model_path, map_location=device)  # type: ASRModel
    else:
        # restore model by name
        asr_model = ASRModel.from_pretrained(model_name=cfg.pretrained_name, map_location=device)  # type: ASRModel

    trainer = pl.Trainer(gpus=int(cfg.cuda))
    asr_model.set_trainer(trainer)
    asr_model = asr_model.eval()

    # Setup decoding strategy
    if hasattr(asr_model, 'change_decoding_strategy'):
        asr_model.change_decoding_strategy(cfg.rnnt_decoding)

    # load paths to audio
    filepaths = list(glob.glob(os.path.join(cfg.audio_dir, f"*.{cfg.audio_type}")))
    logging.info(f"\nTranscribing {len(filepaths)} files...\n")

    # setup AMP (optional)
    if cfg.amp and torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'):
        logging.info("AMP enabled!\n")
        autocast = torch.cuda.amp.autocast
    else:

        @contextlib.contextmanager
        def autocast():
            yield

    # transcribe audio
    with autocast():
        with torch.no_grad():
            transcriptions = asr_model.transcribe(filepaths, batch_size=cfg.batch_size)
    logging.info(f"Finished transcribing {len(filepaths)} files !")

    logging.info(f"Writing transcriptions into file: {cfg.output_filename}")
    with open(cfg.output_filename, 'w', encoding='utf-8') as f:
        for line in transcriptions:
            f.write(f"{line}\n")

    logging.info("Finished writing predictions !")
Example #7
def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
    logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')

    if cfg.model_path is None and cfg.pretrained_name is None:
        raise ValueError(
            "Both cfg.model_path and cfg.pretrained_name cannot be None!")
    if cfg.audio_dir is None and cfg.dataset_manifest is None:
        raise ValueError(
            "Both cfg.audio_dir and cfg.dataset_manifest cannot be None!")

    # setup GPU
    if cfg.cuda is None:
        if torch.cuda.is_available():
            cfg.cuda = 0  # use 0th CUDA device
        else:
            cfg.cuda = -1  # use CPU

    device = torch.device(f'cuda:{cfg.cuda}' if cfg.cuda >= 0 else 'cpu')

    # setup model
    if cfg.model_path is not None:
        # restore model from .nemo file path
        model_cfg = ASRModel.restore_from(restore_path=cfg.model_path,
                                          return_config=True)
        classpath = model_cfg.target  # original class path
        imported_class = model_utils.import_class_by_path(
            classpath)  # type: ASRModel
        logging.info(f"Restoring model : {imported_class.__name__}")
        asr_model = imported_class.restore_from(
            restore_path=cfg.model_path, map_location=device)  # type: ASRModel
        model_name = os.path.splitext(os.path.basename(cfg.model_path))[0]
    else:
        # restore model by name
        asr_model = ASRModel.from_pretrained(
            model_name=cfg.pretrained_name,
            map_location=device)  # type: ASRModel
        model_name = cfg.pretrained_name

    trainer = pl.Trainer(gpus=[cfg.cuda] if cfg.cuda >= 0 else 0)
    asr_model.set_trainer(trainer)
    asr_model = asr_model.eval()

    # Setup decoding strategy
    if hasattr(asr_model, 'change_decoding_strategy'):
        asr_model.change_decoding_strategy(cfg.rnnt_decoding)

    # get audio filenames
    if cfg.audio_dir is not None:
        filepaths = list(
            glob.glob(os.path.join(cfg.audio_dir, f"*.{cfg.audio_type}")))
    else:
        # get filenames from manifest
        filepaths = []
        with open(cfg.dataset_manifest, 'r') as f:
            for line in f:
                item = json.loads(line)
                filepaths.append(item['audio_filepath'])
    logging.info(f"\nTranscribing {len(filepaths)} files...\n")

    # setup AMP (optional)
    if cfg.amp and torch.cuda.is_available() and hasattr(
            torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'):
        logging.info("AMP enabled!\n")
        autocast = torch.cuda.amp.autocast
    else:

        @contextlib.contextmanager
        def autocast():
            yield

    # Compute output filename
    if cfg.output_filename is None:
        # create default output filename
        if cfg.audio_dir is not None:
            cfg.output_filename = os.path.dirname(
                os.path.join(cfg.audio_dir, '.')) + '.json'
        else:
            cfg.output_filename = cfg.dataset_manifest.replace(
                '.json', f'_{model_name}.json')

    # if transcripts should not be overwritten and they already exist, skip the re-transcription step and return
    if not cfg.overwrite_transcripts and os.path.exists(cfg.output_filename):
        logging.info(
            f"Previous transcripts found at {cfg.output_filename}, and flag `overwrite_transcripts`"
            f"is {cfg.overwrite_transcripts}. Returning without re-transcribing text."
        )

        return cfg

    # transcribe audio
    with autocast():
        with torch.no_grad():
            transcriptions = asr_model.transcribe(filepaths,
                                                  batch_size=cfg.batch_size)
    logging.info(f"Finished transcribing {len(filepaths)} files !")

    logging.info(f"Writing transcriptions into file: {cfg.output_filename}")

    # if transcriptions form a tuple (from RNNT), extract just "best" hypothesis
    if type(transcriptions) == tuple and len(transcriptions) == 2:
        transcriptions = transcriptions[0]

    # write audio transcriptions
    with open(cfg.output_filename, 'w', encoding='utf-8') as f:
        if cfg.audio_dir is not None:
            for idx, text in enumerate(transcriptions):
                item = {'audio_filepath': filepaths[idx], 'pred_text': text}
                f.write(json.dumps(item) + "\n")
        else:
            with open(cfg.dataset_manifest, 'r') as fr:
                for idx, line in enumerate(fr):
                    item = json.loads(line)
                    item['pred_text'] = transcriptions[idx]
                    f.write(json.dumps(item) + "\n")

    logging.info("Finished writing predictions !")
    return cfg
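
For reference, each line of the dataset manifest is a standalone JSON object. A minimal input line and the output line this script would write for it (paths are hypothetical):

import json

in_line = {"audio_filepath": "/data/wavs/utt_0001.wav", "duration": 3.2}
out_line = dict(in_line, pred_text="hello world")
print(json.dumps(out_line))
# {"audio_filepath": "/data/wavs/utt_0001.wav", "duration": 3.2, "pred_text": "hello world"}
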
Example #8
def main():
    parser = ArgumentParser()
    parser.add_argument("--model_file",
                        type=str,
                        required=True,
                        help="Path to source .nemo file")
    parser.add_argument("--target_file",
                        type=str,
                        required=True,
                        help="Path to write target .nemo file")
    parser.add_argument("--tensor_model_parallel_size",
                        type=int,
                        required=True,
                        help="TP size of source model")
    parser.add_argument("--target_tensor_model_parallel_size",
                        type=int,
                        required=True,
                        help="TP size of target model")
    parser.add_argument(
        "--model_class",
        type=str,
        default=
        "nemo.collections.nlp.models.language_modeling.megatron_gpt_model.MegatronGPTModel",
        help=
        "NeMo model class. This script should support all NeMo megatron models that use Tensor Parallel",
    )
    parser.add_argument("--precision",
                        default=16,
                        help="PyTorch Lightning Trainer precision flag")

    args = parser.parse_args()

    precision = args.precision
    if args.precision in ["32", "16"]:
        precision = int(float(args.precision))
    tp_size = args.tensor_model_parallel_size
    tgt_tp_size = args.target_tensor_model_parallel_size
    cls = model_utils.import_class_by_path(args.model_class)

    trainer = Trainer(devices=1,
                      plugins=NLPDDPPlugin(),
                      accelerator="cpu",
                      precision=precision)
    app_state = AppState()
    app_state.data_parallel_rank = 0
    app_state.pipeline_model_parallel_size = 1  # not supported yet in this script
    app_state.tensor_model_parallel_size = tp_size
    app_state.model_parallel_size = app_state.pipeline_model_parallel_size * app_state.tensor_model_parallel_size

    if tp_size > 1:
        partitions = []
        for i in range(tp_size):
            app_state.tensor_model_parallel_rank = i
            model = cls.restore_from(restore_path=args.model_file,
                                     trainer=trainer,
                                     map_location=torch.device("cpu"))
            params = [p for _, p in model.named_parameters()]
            partitions.append(params)
            # app_state is being updated incorrectly during restore
            app_state.data_parallel_rank = 0
            app_state.pipeline_model_parallel_size = 1  # not supported yet in this script
            app_state.tensor_model_parallel_size = tp_size
            app_state.model_parallel_size = (
                app_state.pipeline_model_parallel_size *
                app_state.tensor_model_parallel_size)

        model.cfg.tensor_model_parallel_size = 1
        app_state.model_parallel_size = 1
        trainer = Trainer(devices=1,
                          plugins=NLPDDPPlugin(),
                          accelerator="cpu",
                          precision=precision)
        model = cls(model.cfg, trainer).to('cpu')
        model._save_restore_connector = NLPSaveRestoreConnector()

        if tgt_tp_size > 1:
            merge_partition(model, partitions)
        else:
            merge_partition(model, partitions, args.target_file)
    else:
        app_state.model_parallel_size = 1
        model = cls.restore_from(restore_path=args.model_file, trainer=trainer)

    if tgt_tp_size > 1:
        partitions = []
        params = [p for _, p in model.named_parameters()]
        partitions.append(params)

        model.cfg.tensor_model_parallel_size = tgt_tp_size
        app_state.model_parallel_size = tgt_tp_size
        trainer = Trainer(devices=1,
                          plugins=NLPDDPPlugin(),
                          accelerator="cpu",
                          precision=precision)
        model = cls(model.cfg, trainer).to('cpu')
        model._save_restore_connector = NLPSaveRestoreConnector()

        split_partition(model, partitions, tgt_tp_size, args.target_file)

    logging.info("Successfully finished changing partitions!")
Example #9
def main(cfg: TranscriptionConfig) -> TranscriptionConfig:
    logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')

    if is_dataclass(cfg):
        cfg = OmegaConf.structured(cfg)

    if cfg.model_path is None and cfg.pretrained_name is None:
        raise ValueError(
            "Both cfg.model_path and cfg.pretrained_name cannot be None!")
    if cfg.audio_dir is None and cfg.dataset_manifest is None:
        raise ValueError(
            "Both cfg.audio_dir and cfg.dataset_manifest cannot be None!")

    # setup GPU
    if cfg.cuda is None:
        if torch.cuda.is_available():
            device = [0]  # use 0th CUDA device
            accelerator = 'gpu'
        else:
            device = 1
            accelerator = 'cpu'
    else:
        device = [cfg.cuda]
        accelerator = 'gpu'

    map_location = torch.device(
        f'cuda:{device[0]}' if accelerator == 'gpu' else 'cpu')

    # setup model
    if cfg.model_path is not None:
        # restore model from .nemo file path
        model_cfg = ASRModel.restore_from(restore_path=cfg.model_path,
                                          return_config=True)
        classpath = model_cfg.target  # original class path
        imported_class = model_utils.import_class_by_path(
            classpath)  # type: ASRModel
        logging.info(f"Restoring model : {imported_class.__name__}")
        asr_model = imported_class.restore_from(
            restore_path=cfg.model_path,
            map_location=map_location)  # type: ASRModel
        model_name = os.path.splitext(os.path.basename(cfg.model_path))[0]
    else:
        # restore model by name
        asr_model = ASRModel.from_pretrained(
            model_name=cfg.pretrained_name,
            map_location=map_location)  # type: ASRModel
        model_name = cfg.pretrained_name

    trainer = pl.Trainer(devices=device, accelerator=accelerator)
    asr_model.set_trainer(trainer)
    asr_model = asr_model.eval()
    partial_audio = False

    # Setup decoding strategy
    if hasattr(asr_model, 'change_decoding_strategy'):
        asr_model.change_decoding_strategy(cfg.rnnt_decoding)

    # get audio filenames
    if cfg.audio_dir is not None:
        filepaths = list(
            glob.glob(os.path.join(cfg.audio_dir, f"*.{cfg.audio_type}")))
    else:
        # get filenames from manifest
        filepaths = []
        if os.stat(cfg.dataset_manifest).st_size == 0:
            logging.error(
                f"The input dataset_manifest {cfg.dataset_manifest} is empty. Exiting!"
            )
            return None

        with open(cfg.dataset_manifest, 'r') as f:
            has_two_fields = []
            for line in f:
                item = json.loads(line)
                if "offset" in item and "duration" in item:
                    has_two_fields.append(True)
                else:
                    has_two_fields.append(False)
                filepaths.append(item['audio_filepath'])
        partial_audio = all(has_two_fields)

    logging.info(f"\nTranscribing {len(filepaths)} files...\n")

    # setup AMP (optional)
    if cfg.amp and torch.cuda.is_available() and hasattr(
            torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'):
        logging.info("AMP enabled!\n")
        autocast = torch.cuda.amp.autocast
    else:

        @contextlib.contextmanager
        def autocast():
            yield

    # Compute output filename
    if cfg.output_filename is None:
        # create default output filename
        if cfg.audio_dir is not None:
            cfg.output_filename = os.path.dirname(
                os.path.join(cfg.audio_dir, '.')) + '.json'
        else:
            cfg.output_filename = cfg.dataset_manifest.replace(
                '.json', f'_{model_name}.json')

    # if transcripts should not be overwritten and they already exist, skip the re-transcription step and return
    if not cfg.overwrite_transcripts and os.path.exists(cfg.output_filename):
        logging.info(
            f"Previous transcripts found at {cfg.output_filename}, and flag `overwrite_transcripts`"
            f"is {cfg.overwrite_transcripts}. Returning without re-transcribing text."
        )

        return cfg

    # transcribe audio
    with autocast():
        with torch.no_grad():
            if partial_audio:
                if isinstance(asr_model, EncDecCTCModel):
                    transcriptions = transcribe_partial_audio(
                        asr_model=asr_model,
                        path2manifest=cfg.dataset_manifest,
                        batch_size=cfg.batch_size,
                        num_workers=cfg.num_workers,
                    )
                else:
                    logging.warning(
                        "RNNT models do not support transcribe partial audio for now. Transcribing full audio."
                    )
                    transcriptions = asr_model.transcribe(
                        paths2audio_files=filepaths,
                        batch_size=cfg.batch_size,
                        num_workers=cfg.num_workers,
                    )
            else:
                transcriptions = asr_model.transcribe(
                    paths2audio_files=filepaths,
                    batch_size=cfg.batch_size,
                    num_workers=cfg.num_workers,
                )

    logging.info(f"Finished transcribing {len(filepaths)} files !")

    logging.info(f"Writing transcriptions into file: {cfg.output_filename}")

    # if transcriptions form a tuple (from RNNT), extract just "best" hypothesis
    if type(transcriptions) == tuple and len(transcriptions) == 2:
        transcriptions = transcriptions[0]
    # write audio transcriptions
    with open(cfg.output_filename, 'w', encoding='utf-8') as f:
        if cfg.audio_dir is not None:
            for idx, text in enumerate(transcriptions):
                item = {'audio_filepath': filepaths[idx], 'pred_text': text}
                f.write(json.dumps(item) + "\n")
        else:
            with open(cfg.dataset_manifest, 'r') as fr:
                for idx, line in enumerate(fr):
                    item = json.loads(line)
                    item['pred_text'] = transcriptions[idx]
                    f.write(json.dumps(item) + "\n")

    logging.info("Finished writing predictions !")
    return cfg
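
The partial-audio path is taken only when every manifest line carries both an offset and a duration, e.g. (path is hypothetical):

# A manifest line that enables transcribe_partial_audio:
line = {"audio_filepath": "/data/long.wav", "offset": 12.5, "duration": 4.0}
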
Example #10
    def from_config_dict(cls,
                         config: 'DictConfig',
                         trainer: Optional['Trainer'] = None):
        """Instantiates object using DictConfig-based configuration"""
        # Resolve the config dict
        if _HAS_HYDRA:
            if isinstance(config, DictConfig):
                config = OmegaConf.to_container(config, resolve=True)
                config = OmegaConf.create(config)
                OmegaConf.set_struct(config, True)

            config = maybe_update_config_version(config)

        # Hydra 0.x API
        if ('cls' in config
                or 'target' in config) and 'params' in config and _HAS_HYDRA:
            # regular hydra-based instantiation
            instance = hydra.utils.instantiate(config=config)
        # Hydra 1.x API
        elif '_target_' in config and _HAS_HYDRA:
            # regular hydra-based instantiation
            instance = hydra.utils.instantiate(config=config)
        else:
            instance = None
            prev_error = ""
            # Attempt class path resolution from config `target` class (if it exists)
            if 'target' in config:
                target_cls = config["target"]  # No guarantee that this is an omegaconf class
                imported_cls = None
                try:
                    # try to import the target class
                    imported_cls = import_class_by_path(target_cls)
                    # if calling class (cls) is subclass of imported class,
                    # use subclass instead
                    if issubclass(cls, imported_cls):
                        imported_cls = cls
                    accepts_trainer = Serialization._inspect_signature_for_trainer(
                        imported_cls)
                    if accepts_trainer:
                        instance = imported_cls(cfg=config, trainer=trainer)
                    else:
                        instance = imported_cls(cfg=config)
                except Exception as e:
                    # record previous error
                    tb = traceback.format_exc()
                    prev_error = f"Model instantiation failed!\nTarget class:\t{target_cls}" f"\nError(s):\t{e}\n{tb}"
                    logging.debug(prev_error + "\nFalling back to `cls`.")

            # target class resolution was unsuccessful, fall back to current `cls`
            if instance is None:
                try:
                    accepts_trainer = Serialization._inspect_signature_for_trainer(
                        cls)
                    if accepts_trainer:
                        instance = cls(cfg=config, trainer=trainer)
                    else:
                        instance = cls(cfg=config)

                except Exception as e:
                    # report saved errors, if any, and raise
                    if prev_error:
                        logging.error(prev_error)
                    raise e

        if not hasattr(instance, '_cfg'):
            instance._cfg = config
        return instance
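
A minimal sketch of what a signature check like Serialization._inspect_signature_for_trainer might do (an assumption for illustration, not NeMo's exact code):

import inspect

def _accepts_trainer(cls) -> bool:
    # Does cls.__init__ declare a 'trainer' parameter?
    try:
        return 'trainer' in inspect.signature(cls.__init__).parameters
    except (TypeError, ValueError):
        return False
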
Example #11
def main(cfg: TranscriptionConfig):
    logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')

    if cfg.model_path is None and cfg.pretrained_name is None:
        raise ValueError(
            "Both cfg.model_path and cfg.pretrained_name cannot be None!")
    if cfg.audio_dir is None and cfg.dataset_manifest is None:
        raise ValueError(
            "Both cfg.audio_dir and cfg.dataset_manifest cannot be None!")

    # setup GPU
    if cfg.cuda is None:
        cfg.cuda = torch.cuda.is_available()

    if type(cfg.cuda) == int:
        device_id = int(cfg.cuda)
    else:
        device_id = 0

    device = torch.device(f'cuda:{device_id}' if cfg.cuda else 'cpu')

    # setup model
    if cfg.model_path is not None:
        # restore model from .nemo file path
        model_cfg = ASRModel.restore_from(restore_path=cfg.model_path,
                                          return_config=True)
        classpath = model_cfg.target  # original class path
        imported_class = model_utils.import_class_by_path(
            classpath)  # type: ASRModel
        logging.info(f"Restoring model : {imported_class.__name__}")
        asr_model = imported_class.restore_from(
            restore_path=cfg.model_path, map_location=device)  # type: ASRModel
        model_name = os.path.splitext(os.path.basename(cfg.model_path))[0]
    else:
        # restore model by name
        asr_model = ASRModel.from_pretrained(
            model_name=cfg.pretrained_name,
            map_location=device)  # type: ASRModel
        model_name = cfg.pretrained_name

    trainer = pl.Trainer(gpus=int(cfg.cuda))
    asr_model.set_trainer(trainer)
    asr_model = asr_model.eval()

    # Setup decoding strategy
    if hasattr(asr_model, 'change_decoding_strategy'):
        asr_model.change_decoding_strategy(cfg.rnnt_decoding)

    # get audio filenames
    if cfg.audio_dir is not None:
        filepaths = list(
            glob.glob(os.path.join(cfg.audio_dir, f"*.{cfg.audio_type}")))
    else:
        # get filenames from manifest
        filepaths = []
        with open(cfg.dataset_manifest, 'r') as f:
            for line in f:
                item = json.loads(line)
                filepaths.append(item['audio_filepath'])
    logging.info(f"\nTranscribing {len(filepaths)} files...\n")

    # setup AMP (optional)
    if cfg.amp and torch.cuda.is_available() and hasattr(
            torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'):
        logging.info("AMP enabled!\n")
        autocast = torch.cuda.amp.autocast
    else:

        @contextlib.contextmanager
        def autocast():
            yield

    # transcribe audio
    with autocast():
        with torch.no_grad():
            transcriptions = asr_model.transcribe(filepaths,
                                                  batch_size=cfg.batch_size)
    logging.info(f"Finished transcribing {len(filepaths)} files !")

    if cfg.output_filename is None:
        # create default output filename
        if cfg.audio_dir is not None:
            cfg.output_filename = os.path.dirname(
                os.path.join(cfg.audio_dir, '.')) + '.json'
        else:
            cfg.output_filename = cfg.dataset_manifest.replace(
                '.json', f'_{model_name}.json')

    logging.info(f"Writing transcriptions into file: {cfg.output_filename}")

    with open(cfg.output_filename, 'w', encoding='utf-8') as f:
        if cfg.audio_dir is not None:
            for idx, text in enumerate(transcriptions):
                item = {'audio_filepath': filepaths[idx], 'pred_text': text}
                f.write(json.dumps(item) + "\n")
        else:
            with open(cfg.dataset_manifest, 'r') as fr:
                for idx, line in enumerate(fr):
                    item = json.loads(line)
                    item['pred_text'] = transcriptions[idx]
                    f.write(json.dumps(item) + "\n")

    logging.info("Finished writing predictions !")
Example #12
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        'model_fname_list',
        metavar='N',
        type=str,
        nargs='+',
        help='Input .nemo files (or folders that contain them) to parse',
    )
    parser.add_argument(
        '--import_fname_list',
        type=str,
        nargs='+',
        default=[],
        help=
        'A list of Python file names to "from FILE import *" (Needed when some classes were defined in __main__ of a script)',
    )
    args = parser.parse_args()

    logging.info(
        f"\n\nIMPORTANT: Use --import_fname_list for all files that contain missing classes (AttributeError: Can't get attribute '???' on <module '__main__' from '???'>)\n\n"
    )

    for fn in args.import_fname_list:
        logging.info(f"Importing * from {fn}")
        sys.path.insert(0, os.path.dirname(fn))
        globals().update(
            importlib.import_module(os.path.splitext(
                os.path.basename(fn))[0]).__dict__)

    device = torch.device("cpu")

    # loop over all folders with .nemo files (or .nemo files)
    for model_fname_i, model_fname in enumerate(args.model_fname_list):
        if not model_fname.endswith(".nemo"):
            # assume model_fname is a folder containing a single .nemo file (skip .nemo files matching "*-averaged.nemo")
            nemo_files = list(
                filter(lambda fn: not fn.endswith("-averaged.nemo"),
                       glob.glob(os.path.join(model_fname, "*.nemo"))))
            if len(nemo_files) != 1:
                raise RuntimeError(
                    f"Expected only a single .nemo files but discovered {len(nemo_files)} .nemo files"
                )

            model_fname = nemo_files[0]

        model_folder_path = os.path.dirname(model_fname)
        fn, fe = os.path.splitext(model_fname)
        avg_model_fname = f"{fn}-averaged{fe}"

        logging.info(
            f"\n===> [{model_fname_i+1} / {len(args.model_fname_list)}] Parsing folder {model_folder_path}\n"
        )

        # restore model from .nemo file path
        model_cfg = ModelPT.restore_from(restore_path=model_fname,
                                         return_config=True)
        classpath = model_cfg.target  # original class path
        imported_class = model_utils.import_class_by_path(classpath)
        logging.info(f"Loading model {model_fname}")
        nemo_model = imported_class.restore_from(restore_path=model_fname,
                                                 map_location=device)

        # search for all checkpoints (ignore -last.ckpt)
        checkpoint_paths = [
            os.path.join(model_folder_path, x)
            for x in os.listdir(model_folder_path)
            if x.endswith('.ckpt') and not x.endswith('-last.ckpt')
        ]
        """ < Checkpoint Averaging Logic > """
        # load state dicts
        n = len(checkpoint_paths)
        avg_state = None

        logging.info(f"Averaging {n} checkpoints ...")

        for ix, path in enumerate(checkpoint_paths):
            checkpoint = torch.load(path, map_location=device)

            if 'state_dict' in checkpoint:
                checkpoint = checkpoint['state_dict']

            if ix == 0:
                # Initial state
                avg_state = checkpoint

                logging.info(
                    f"Initialized average state dict with checkpoint : {path}")
            else:
                # Accumulated state
                for k in avg_state:
                    avg_state[k] = avg_state[k] + checkpoint[k]

                logging.info(
                    f"Updated average state dict with state from checkpoint : {path}"
                )

        for k in avg_state:
            if str(avg_state[k].dtype).startswith("torch.int"):
                # Integer tensors are accumulated, not averaged
                # (e.g. BatchNorm.num_batches_tracked).
                pass
            else:
                avg_state[k] = avg_state[k] / n

        # restore merged weights into model
        nemo_model.load_state_dict(avg_state, strict=True)
        # Save model
        logging.info(f"Saving average mdel to: {avg_model_fname}")
        nemo_model.save_to(avg_model_fname)
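
The averaging arithmetic on a toy pair of state dicts (float tensors are summed then divided; integer tensors would be left as accumulated counts):

import torch

ckpts = [{'w': torch.tensor([1.0, 2.0])}, {'w': torch.tensor([3.0, 4.0])}]
avg = {k: sum(c[k] for c in ckpts) / len(ckpts) for k in ckpts[0]}
print(avg)  # {'w': tensor([2., 3.])}
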
Example #13
if __name__ == "__main__":
    args = parser.parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    text_files = []
    if args.in_text:
        if args.model is None:
            raise ValueError(
                f"ASR model must be provided to extract vocabulary for text processing"
            )
        elif os.path.exists(args.model):
            model_cfg = ASRModel.restore_from(restore_path=args.model,
                                              return_config=True)
            classpath = model_cfg.target  # original class path
            imported_class = model_utils.import_class_by_path(
                classpath)  # type: ASRModel
            print(f"Restoring model : {imported_class.__name__}")
            asr_model = imported_class.restore_from(
                restore_path=args.model)  # type: ASRModel
            model_name = os.path.splitext(os.path.basename(args.model))[0]
        else:
            # restore model by name
            asr_model = ASRModel.from_pretrained(
                model_name=args.model)  # type: ASRModel
            model_name = args.model

        vocabulary = asr_model.cfg.decoder.vocabulary

        if os.path.isdir(args.in_text):
            text_files = Path(args.in_text).glob("*.txt")
        else: