def get_trained_model_state_dict(model_path):
    """Extract the trained model state dict for pre-initialization.

    Args:
        model_path (str): Path to model.***.best

    Returns:
        OrderedDict: The loaded model state_dict.

    """
    conf_path = os.path.join(os.path.dirname(model_path), "model.json")
    if "rnnlm" in model_path:
        logging.warning("reading model parameters from %s", model_path)

        return get_lm_state_dict(torch.load(model_path))

    idim, odim, args = get_model_conf(model_path, conf_path)

    logging.warning("reading model parameters from " + model_path)

    if hasattr(args, "model_module"):
        model_module = args.model_module
    else:
        model_module = "espnet.nets.pytorch_backend.e2e_asr:E2E"

    model_class = dynamic_import(model_module)
    model = model_class(idim, odim, args)
    torch_load(model_path, model)
    assert (isinstance(model, MTInterface) or isinstance(model, ASRInterface)
            or isinstance(model, TTSInterface))

    return model.state_dict()
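A minimal usage sketch (not part of the scraped source); the checkpoint path is a hypothetical placeholder.

# Hypothetical pre-initialization flow built on the helper above.
state_dict = get_trained_model_state_dict("exp/train_asr/results/model.acc.best")
# Inspect a few parameter names and shapes before transferring them.
for name, tensor in list(state_dict.items())[:5]:
    print(name, tuple(tensor.shape))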
Example #2
    def __init__(self, conffile=None):
        if conffile is not None:
            if isinstance(conffile, dict):
                self.conf = copy.deepcopy(conffile)
            else:
                with io.open(conffile, encoding='utf-8') as f:
                    self.conf = yaml.safe_load(f)
                    assert isinstance(self.conf, dict), type(self.conf)
        else:
            self.conf = {'mode': 'sequential', 'process': []}

        self.functions = OrderedDict()
        if self.conf.get('mode', 'sequential') == 'sequential':
            for idx, process in enumerate(self.conf['process']):
                assert isinstance(process, dict), type(process)
                opts = dict(process)
                process_type = opts.pop('type')
                class_obj = dynamic_import(process_type, import_alias)
                # TODO(karita): assert issubclass(class_obj, TransformInterface)
                check_kwargs(class_obj, opts)
                try:
                    self.functions[idx] = class_obj(**opts)
                except TypeError:
                    try:
                        signa = signature(class_obj)
                    except ValueError:
                        # signature() fails for some callables, e.g. built-in functions
                        pass
                    else:
                        logging.error('Expected signature: {}({})'.format(
                            class_obj.__name__, signa))
                    raise
        else:
            raise NotImplementedError('Not supporting mode={}'.format(
                self.conf['mode']))
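For reference, a hedged sketch of a configuration this constructor would accept; the process entries are illustrative and assume "fbank" and "cmvn" are registered in import_alias.

# Hypothetical sequential transform configuration.
conf = {
    "mode": "sequential",
    "process": [
        {"type": "fbank", "n_mels": 80, "fs": 16000},
        {"type": "cmvn", "stats": "data/train/cmvn.ark", "norm_vars": True},
    ],
}
# transform = Transformation(conf)  # assuming this __init__ belongs to Transformation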
Example #3
def dynamic_import_optimizer(name: str,
                             backend: str) -> OptimizerFactoryInterface:
    """Import optimizer class dynamically.

    Args:
        name (str): alias name or dynamic import syntax `module:class`
        backend (str): backend name e.g., chainer or pytorch

    Returns:
        OptimizerFactoryInterface or FunctionalOptimizerAdaptor

    """
    if backend == "pytorch":
        from espnet.optimizer.pytorch import OPTIMIZER_FACTORY_DICT

        return OPTIMIZER_FACTORY_DICT[name]
    elif backend == "chainer":
        from espnet.optimizer.chainer import OPTIMIZER_FACTORY_DICT

        return OPTIMIZER_FACTORY_DICT[name]
    else:
        raise NotImplementedError(f"unsupported backend: {backend}")

    factory_class = dynamic_import(name)
    assert issubclass(factory_class, OptimizerFactoryInterface)
    return factory_class
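A short usage sketch; "adam" is assumed to be a registered alias, and the commented line shows the `module:class` syntax from the docstring.

factory = dynamic_import_optimizer("adam", "pytorch")
# factory = dynamic_import_optimizer("my_pkg.optim:MyFactory", "pytorch")  # hypothetical class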
Example #4
def load_trained_model(model_path):
    """Load the trained model.

    Args:
        model_path (str): Path to model.***.best

    """
    # read training config
    idim, odim, train_args = get_model_conf(
        model_path, os.path.join(os.path.dirname(model_path), 'model.json'))

    # load trained model parameters
    logging.info('reading model parameters from ' + model_path)
    # To be compatible with v.0.3.0 models
    if hasattr(train_args, "model_module"):
        model_module = train_args.model_module
    else:
        model_module = "espnet.nets.pytorch_backend.e2e_asr_ftransformer:E2E"
    model_class = dynamic_import(model_module)
    model = model_class(idim, odim, train_args)
    # NOTE: this checkpoint stores the whole pickled model object, so it
    # replaces the freshly constructed model above rather than being loaded
    # into it via torch_load().
    model = torch.load(model_path).cpu()

    return model, train_args
Example #5
    def __init__(self, conffile=None):
        if conffile is not None:
            if isinstance(conffile, dict):
                self.conf = copy.deepcopy(conffile)
            else:
                with io.open(conffile, encoding="utf-8") as f:
                    self.conf = yaml.safe_load(f)
                    assert isinstance(self.conf, dict), type(self.conf)
        else:
            self.conf = {"mode": "sequential", "process": []}

        self.functions = OrderedDict()
        if self.conf.get("mode", "sequential") == "sequential":
            for idx, process in enumerate(self.conf["process"]):
                assert isinstance(process, dict), type(process)
                opts = dict(process)
                process_type = opts.pop("type")
                class_obj = dynamic_import(process_type, import_alias)
                # TODO(karita): assert issubclass(class_obj, TransformInterface)
                try:
                    self.functions[idx] = class_obj(**opts)
                except TypeError:
                    try:
                        signa = signature(class_obj)
                    except ValueError:
                        # signature() fails for some callables, e.g. built-in functions
                        pass
                    else:
                        logging.error("Expected signature: {}({})".format(
                            class_obj.__name__, signa))
                    raise
        else:
            raise NotImplementedError("Not supporting mode={}".format(
                self.conf["mode"]))
Example #6
    def __init__(self, device='cpu'):
        dict_path = "downloads/data/lang_1char/train_no_dev_units.txt"
        model_path = "downloads/exp/train_no_dev_pytorch_train_pytorch_tacotron2.v3/results/model.last1.avg.best"
        vocoder_path = "downloads/ljspeech.parallel_wavegan.v1/checkpoint-400000steps.pkl"
        vocoder_conf = "downloads/ljspeech.parallel_wavegan.v1/config.yml"

        device = torch.device(device)

        idim, odim, train_args = get_model_conf(model_path)
        model_class = dynamic_import(train_args.model_module)
        model = model_class(idim, odim, train_args)
        torch_load(model_path, model)
        model = model.eval().to(device)
        inference_args = Namespace(**{"threshold": 0.5, "minlenratio": 0.0, "maxlenratio": 10.0})

        with open(vocoder_conf) as f:
            config = yaml.load(f, Loader=yaml.Loader)
        vocoder = ParallelWaveGANGenerator(**config["generator_params"])
        vocoder.load_state_dict(torch.load(vocoder_path, map_location="cpu")["model"]["generator"])
        vocoder.remove_weight_norm()
        vocoder = vocoder.eval().to(device)

        with open(dict_path) as f:
            lines = f.readlines()
        lines = [line.replace("\n", "").split(" ") for line in lines]
        char_to_id = {c: int(i) for c, i in lines}

        self.device = device
        self.char_to_id = char_to_id
        self.idim = idim
        self.model = model
        self.inference_args = inference_args
        self.config = config
        self.vocoder = vocoder
Example #7
def decode(args):
    """Decode with E2E-TTS model."""
    set_deterministic_pytorch(args)
    # read training config
    idim, odim, train_args = get_model_conf(args.model, args.model_conf)

    # show arguments
    for key in sorted(vars(args).keys()):
        logging.info('args: ' + key + ': ' + str(vars(args)[key]))

    # define model
    model_class = dynamic_import(train_args.model_module)
    model = model_class(idim, odim, train_args)
    assert isinstance(model, TTSInterface)
    logging.info(model)

    # load trained model parameters
    logging.info('reading model parameters from ' + args.model)
    torch_load(args.model, model)
    model.eval()

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    model = model.to(device)

    # generate speaker embeddings
    SequenceGenerator(model.resnet_spkid, device, args.feat_scp, args.out_file)
Example #8
def load_trained_model(model_path):
    """Load the trained model.

    Args:
        model_path (str): Path to model.***.best

    """
    # read training config
    idim, odim, train_args = get_model_conf(
        model_path, os.path.join(os.path.dirname(model_path), 'model.json'))

    # load trained model parameters
    logging.info('reading model parameters from ' + model_path)
    # To be compatible with v.0.3.0 models
    if hasattr(train_args, "model_module"):
        model_module = train_args.model_module
    else:
        model_module = "espnet.nets.pytorch_backend.e2e_asr:E2E"
    model_class = dynamic_import(model_module)
    model = model_class(idim, odim, train_args)

    if hasattr(train_args,
               "slu_model") and train_args.slu_model and train_args.slu_loss:
        if hasattr(train_args, 'slu_pooling'):
            model.add_slu(train_args.slu_model, train_args.slu_loss,
                          train_args.slu_tune_weights, train_args.slu_pooling)
        else:
            model.add_slu(train_args.slu_model, train_args.slu_loss,
                          train_args.slu_tune_weights, '')

    torch_load(model_path, model)

    return model, train_args
Example #9
File: asr_init.py  Project: zouwei02/espnet
def get_trained_model_state_dict(model_path):
    """Extract the trained model state dict for pre-initialization.

    Args:
        model_path (str): Path to model.***.best

    Returns:
        OrderedDict: The loaded model state_dict.
        str: Model type, either 'asr-mt' or 'lm'.

    """
    conf_path = os.path.join(os.path.dirname(model_path), 'model.json')
    if 'rnnlm' in model_path:
        logging.warning('reading model parameters from %s', model_path)

        return torch.load(model_path), 'lm'

    idim, odim, args = get_model_conf(model_path, conf_path)

    logging.warning('reading model parameters from ' + model_path)

    if hasattr(args, "model_module"):
        model_module = args.model_module
    else:
        model_module = "espnet.nets.pytorch_backend.e2e_asr:E2E"

    model_class = dynamic_import(model_module)
    model = model_class(idim, odim, args)
    torch_load(model_path, model)
    assert isinstance(model, MTInterface) or isinstance(model, ASRInterface)

    return model.state_dict(), 'asr-mt'
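A small illustration (hypothetical path) of the mode flag this variant returns; it pairs with the load_trained_modules variant in Example #17 below.

state_dict, mode = get_trained_model_state_dict("exp/train_rnnlm/rnnlm.model.best")
print(mode)  # 'lm', because the path contains 'rnnlm'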
Example #10
    def __init__(self, config):

        logging.info("Worker2 init start:")

        self.tmp_dir = config['tmp']
        os.makedirs(self.tmp_dir, exist_ok=True)
        self.debug_mode = config['debug_mode']

        self.phone2id = frontend.LoadDictionary(config['dict_path'])

        self.idim, odim, train_args = get_model_conf(
            config['acoustic_model_path'])

        model_class = dynamic_import(train_args.model_module)
        model = model_class(self.idim, odim, train_args)
        torch_load(config['acoustic_model_path'], model)

        logging.info("Model Load Done {}".format(
            config['acoustic_model_path']))

        self.model = model.eval().to(device)
        self.inference_args = Namespace(
            **{
                "threshold": config['threshold'],
                "minlenratio": config['minlenratio'],
                "maxlenratio": config['maxlenratio']
            })

        logging.info("TacotronLPCNetWorker Init Done")

        self.num_chunk_frame = config["num_chunk_frame"]
        logging.info("Chunk size: {}".format(self.num_chunk_frame))
        self.overlap = config["overlap"]
        logging.info("Overlap: {}".format(self.overlap))
def load_trained_model(model_path, training=True):
    """Load the trained model for recognition.

    Args:
        model_path (str): Path to model.***.best
        training (bool): Training mode specification for transducer model.

    """
    idim, odim, train_args = get_model_conf(
        model_path, os.path.join(os.path.dirname(model_path), "model.json"))

    logging.warning("reading model parameters from " + model_path)

    if hasattr(train_args, "model_module"):
        model_module = train_args.model_module
    else:
        model_module = "espnet.nets.pytorch_backend.e2e_asr:E2E"
    # CTC loss is not needed at decoding time; default to "builtin" to prevent import errors
    if hasattr(train_args, "ctc_type"):
        train_args.ctc_type = "builtin"

    model_class = dynamic_import(model_module)

    if "transducer" in model_module:
        model = model_class(idim, odim, train_args, training)
    else:
        model = model_class(idim, odim, train_args)
    torch_load(model_path, model)

    return model, train_args
Example #12
File: tts_train.py  Project: zqs01/espnet
def main(cmd_args):
    parser = get_parser()
    args, _ = parser.parse_known_args(cmd_args)

    from espnet.utils.dynamic_import import dynamic_import
    model_class = dynamic_import(args.model_module)
    assert issubclass(model_class, TTSInterface)
    model_class.add_arguments(parser)
    args = parser.parse_args(cmd_args)

    # logging info
    if args.verbose > 0:
        logging.basicConfig(
            level=logging.INFO,
            format=
            '%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s')
    else:
        logging.basicConfig(
            level=logging.WARN,
            format=
            '%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s')
        logging.warning('Skip DEBUG/INFO messages')

    # check CUDA_VISIBLE_DEVICES
    if args.ngpu > 0:
        # python 2 case
        if platform.python_version_tuple()[0] == '2':
            if "clsp.jhu.edu" in subprocess.check_output(["hostname", "-f"]):
                cvd = subprocess.check_output(
                    ["/usr/local/bin/free-gpu", "-n",
                     str(args.ngpu)]).strip()
                logging.info('CLSP: use gpu ' + cvd)
                os.environ['CUDA_VISIBLE_DEVICES'] = cvd
        # python 3 case
        else:
            if "clsp.jhu.edu" in subprocess.check_output(["hostname",
                                                          "-f"]).decode():
                cvd = subprocess.check_output(
                    ["/usr/local/bin/free-gpu", "-n",
                     str(args.ngpu)]).decode().strip()
                logging.info('CLSP: use gpu ' + cvd)
                os.environ['CUDA_VISIBLE_DEVICES'] = cvd

        cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
        if cvd is None:
            logging.warning("CUDA_VISIBLE_DEVICES is not set.")
        elif args.ngpu != len(cvd.split(",")):
            logging.error("#gpus is not matched with CUDA_VISIBLE_DEVICES.")
            sys.exit(1)

    # set random seed
    logging.info('random seed = %d' % args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    if args.backend == "pytorch":
        from espnet.tts.pytorch_backend.tts import train
        train(args)
    else:
        raise NotImplementedError("Only pytorch is supported.")
Example #13
def main(cmd_args):
    """Run training."""
    parser = get_parser()
    args, _ = parser.parse_known_args(cmd_args)

    from espnet.utils.dynamic_import import dynamic_import

    model_class = dynamic_import(args.model_module)
    assert issubclass(model_class, TTSInterface)
    model_class.add_arguments(parser)
    args = parser.parse_args(cmd_args)

    # logging info
    if args.verbose > 0:
        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    else:
        logging.basicConfig(
            level=logging.WARN,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
        logging.warning("Skip DEBUG/INFO messages")

    # If --ngpu is not given,
    #   1. if CUDA_VISIBLE_DEVICES is set, all visible devices
    #   2. if nvidia-smi exists, use all devices
    #   3. else ngpu=0
    if args.ngpu is None:
        cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
        if cvd is not None:
            ngpu = len(cvd.split(","))
        else:
            logging.warning("CUDA_VISIBLE_DEVICES is not set.")
            try:
                p = subprocess.run(
                    ["nvidia-smi", "-L"], stdout=subprocess.PIPE, stderr=subprocess.PIPE
                )
            except (subprocess.CalledProcessError, FileNotFoundError):
                ngpu = 0
            else:
                # nvidia-smi -L prints one line per GPU on stdout
                ngpu = len(p.stdout.decode().split("\n")) - 1
        args.ngpu = ngpu
    else:
        ngpu = args.ngpu
    logging.info(f"ngpu: {ngpu}")

    # set random seed
    logging.info("random seed = %d" % args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    if args.backend == "pytorch":
        from espnet.tts.pytorch_backend.tts import train

        train(args)
    else:
        raise NotImplementedError("Only pytorch is supported.")
Example #14
def viterbi_decode(args):
    set_deterministic_pytorch(args)
    idim, odim, train_args = get_model_conf(
        args.model, os.path.join(os.path.dirname(args.model), 'model.json'))
    model_class = dynamic_import(train_args.model_module)
    model = model_class(idim, odim, train_args)
    if args.model is not None:
        load_params = dict(torch.load(args.model))
        if 'model' in load_params:
            load_params = dict(load_params['model'])
        if 'state_dict' in load_params:
            load_params = dict(load_params['state_dict'])
        model_params = dict(model.named_parameters())
        for k, v in load_params.items():
            k = k.replace('module.', '')
            if k in model_params and v.size() == model_params[k].size():
                model_params[k].data = v.data
                logging.warning('load parameters {}'.format(k))
    model.recog_args = args

    if args.ngpu == 1:
        gpu_id = list(range(args.ngpu))
        logging.info('gpu id: ' + str(gpu_id))
        model.cuda()

    with open(args.recog_json, 'rb') as f:
        js = json.load(f)['utts']
    new_js = {}

    load_inputs_and_targets = LoadInputsAndTargets(
        mode='asr',
        load_output=False,
        sort_in_input_length=False,
        preprocess_conf=train_args.preprocess_conf
        if args.preprocess_conf is None else args.preprocess_conf,
        preprocess_args={'train': False})

    with torch.no_grad():
        for idx, name in enumerate(js.keys(), 1):
            logging.info('(%d/%d) decoding ' + name, idx, len(js.keys()))
            batch = [(name, js[name])]
            feat = load_inputs_and_targets(batch)
            y = np.fromiter(map(int,
                                batch[0][1]['output'][0]['tokenid'].split()),
                            dtype=np.int64)
            align = model.viterbi_decode(feat[0][0], y)
            assert len(align) == len(y)
            new_js[name] = js[name]
            new_js[name]['output'][0]['align'] = ' '.join(
                [str(i) for i in list(align)])

    with open(args.result_label, 'wb') as f:
        f.write(
            json.dumps({
                'utts': new_js
            },
                       indent=4,
                       ensure_ascii=False,
                       sort_keys=True).encode('utf_8'))
Example #15
def load_trained_modules(idim, odim, args, interface=ASRInterface):
    """Load model encoder or/and decoder modules with ESPNET pre-trained model(s).

    Args:
        idim (int): initial input dimension.
        odim (int): initial output dimension.
        args (Namespace): The initial model arguments.
        interface (Interface): ASRInterface or STInterface or TTSInterface.

    Return:
        model (torch.nn.Module): The model with pretrained modules.

    """
    def print_new_keys(state_dict, modules, model_path):
        logging.warning("loading %s from model: %s", modules, model_path)

        for k in state_dict.keys():
            logging.warning("override %s" % k)

    enc_model_path = args.enc_init
    dec_model_path = args.dec_init
    enc_modules = args.enc_init_mods
    dec_modules = args.dec_init_mods

    model_class = dynamic_import(args.model_module)
    main_model = model_class(idim, odim, args)
    assert isinstance(main_model, interface)

    main_state_dict = main_model.state_dict()

    logging.warning("model(s) found for pre-initialization")
    for model_path, modules in [
        (enc_model_path, enc_modules),
        (dec_model_path, dec_modules),
    ]:
        if model_path is not None:
            if os.path.isfile(model_path):
                model_state_dict = get_trained_model_state_dict(model_path)

                modules = filter_modules(model_state_dict, modules)

                partial_state_dict = get_partial_state_dict(
                    model_state_dict, modules)

                if partial_state_dict:
                    if transfer_verification(main_state_dict,
                                             partial_state_dict, modules):
                        print_new_keys(partial_state_dict, modules, model_path)
                        main_state_dict.update(partial_state_dict)
                    else:
                        logging.warning(
                            f"modules {modules} in model {model_path} "
                            f"don't match your training config", )
            else:
                logging.warning("model was not found : %s", model_path)

    main_model.load_state_dict(main_state_dict)

    return main_model
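A hedged sketch of the argument namespace this function reads; the field names come from the body above, but the paths, module prefixes, and dimensions are placeholders, and a real E2E constructor also needs its full set of training arguments (omitted here).

from argparse import Namespace

args = Namespace(
    model_module="espnet.nets.pytorch_backend.e2e_asr:E2E",
    enc_init="exp/pretrained_asr/results/model.acc.best",
    enc_init_mods=["enc."],
    dec_init=None,
    dec_init_mods=["dec.", "att."],
)
# model = load_trained_modules(idim=80, odim=5000, args=args, interface=ASRInterface)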
Example #16
def load_trained_modules(idim, odim, args, interface=ASRInterface):
    """Load ASR/MT/TTS model with pre-trained weights for specified modules.

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        args (Namespace): Model arguments.
        interface (ASRInterface|MTInterface|TTSInterface): Model interface.

    Return:
        main_model (torch.nn.Module): Model with pre-initialized weights.

    """
    def print_new_keys(state_dict, modules, model_path):
        logging.info(f"Loading {modules} from model: {model_path}")

        for k in state_dict.keys():
            logging.warning(f"Overriding module {k}")

    enc_model_path = args.enc_init
    dec_model_path = args.dec_init
    enc_modules = args.enc_init_mods
    dec_modules = args.dec_init_mods

    model_class = dynamic_import(args.model_module)
    main_model = model_class(idim, odim, args)
    assert isinstance(main_model, interface)

    main_state_dict = main_model.state_dict()
    logging.warning("Model(s) found for pre-initialization.")

    for model_path, modules in [
        (enc_model_path, enc_modules),
        (dec_model_path, dec_modules),
    ]:
        if model_path is not None:
            if os.path.isfile(model_path):
                model_state_dict = get_trained_model_state_dict(
                    model_path, "transducer" in args.model_module)
                modules = filter_modules(model_state_dict, modules)

                partial_state_dict = get_partial_state_dict(
                    model_state_dict, modules)

                if partial_state_dict:
                    if transfer_verification(main_state_dict,
                                             partial_state_dict, modules):
                        print_new_keys(partial_state_dict, modules, model_path)
                        main_state_dict.update(partial_state_dict)
            else:
                logging.error(f"Specified model was not found: {model_path}")
                exit(1)

    main_model.load_state_dict(main_state_dict)

    return main_model
Example #17
def load_trained_modules(idim, odim, args):
    """Load model encoder or/and decoder modules with ESPNET pre-trained model(s).

    Args:
        idim (int): initial input dimension.
        odim (int): initial output dimension.
        args (namespace): The initial model arguments.

    Return:
        model (torch.nn.Module): The model with pretrained modules.

    """
    enc_model_path = args.enc_init
    dec_model_path = args.dec_init
    enc_modules = args.enc_init_mods
    dec_modules = args.dec_init_mods

    model_class = dynamic_import(args.model_module)
    main_model = model_class(idim, odim, args)
    assert isinstance(main_model, ASRInterface)

    main_state_dict = main_model.state_dict()

    logging.info('model(s) found for pre-initialization')
    for model_path, modules in [(enc_model_path, enc_modules),
                                (dec_model_path, dec_modules)]:
        if model_path is not None:
            if os.path.isfile(model_path):
                model_state_dict, mode = get_trained_model_state_dict(
                    model_path)

                modules = filter_modules(model_state_dict, modules)
                if mode == 'lm':
                    partial_state_dict, modules = get_partial_lm_state_dict(
                        model_state_dict, modules)
                else:
                    partial_state_dict = get_partial_asr_mt_state_dict(
                        model_state_dict, modules)

                if partial_state_dict:
                    if transfer_verification(main_state_dict,
                                             partial_state_dict, modules):
                        logging.info('loading %s from model: %s', modules,
                                     model_path)
                        main_state_dict.update(partial_state_dict)
                    else:
                        logging.info(
                            'modules %s in model %s don\'t match your training config',
                            modules, model_path)
            else:
                logging.info('model was not found : %s', model_path)

    main_model.load_state_dict(main_state_dict)

    return main_model
Example #18
def dynamic_import_lm(module, backend):
    """Import LM class dynamically.

    Args:
        module (str): module_name:class_name or alias in `predefined_lms`
        backend (str): NN backend. e.g., pytorch, chainer

    Returns:
        type: LM class

    """
    model_class = dynamic_import(module, predefined_lms.get(backend, dict()))
    assert issubclass(model_class, LMInterface), f"{module} does not implement LMInterface"
    return model_class
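A short usage sketch; "default" is assumed to be one of the predefined pytorch aliases.

lm_class = dynamic_import_lm("default", "pytorch")
# lm_class = dynamic_import_lm("my_pkg.lm:MyLM", "pytorch")  # hypothetical module:class form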
Example #19
def dynamic_import_scheduler(module):
    """Import Scheduler class dynamically.

    Args:
        module (str): module_name:class_name or alias in `SCHEDULER_DICT`

    Returns:
        type: Scheduler class

    """
    model_class = dynamic_import(module, SCHEDULER_DICT)
    assert issubclass(
        model_class,
        SchedulerInterface), f"{module} does not implement SchedulerInterface"
    return model_class
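Likewise for schedulers, a hedged sketch assuming "noam" is registered in SCHEDULER_DICT.

scheduler_class = dynamic_import_scheduler("noam")
# scheduler_class = dynamic_import_scheduler("my_pkg.sched:MyScheduler")  # hypothetical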
Example #20
    def __init__(self):
        # TODO Move this into a config File, give option of different models
        self.trans_type = "phn"
        dict_path = "/home/ntrusse2/espnet/downloads/en/fastspeech/data/lang_1phn/phn_train_no_dev_units.txt"
        model_path = "/home/ntrusse2/espnet/downloads/en/fastspeech/exp/phn_train_no_dev_pytorch_train_tacotron2.v3_fastspeech.v4.single/results/model.last1.avg.best"
        vocoder_path = "/home/ntrusse2/espnet/downloads/en/parallel_wavegan/ljspeech.parallel_wavegan.v2/checkpoint-400000steps.pkl"
        vocoder_conf = "/home/ntrusse2/espnet/downloads/en/parallel_wavegan/ljspeech.parallel_wavegan.v2/config.yml"

        # Copied right out of the examples in ESPnet's demo
        self.device = torch.device("cuda")
        print("Loading Torch Model...")
        self.idim, odim, train_args = get_model_conf(model_path)
        model_class = dynamic_import(train_args.model_module)
        model = model_class(self.idim, odim, train_args)
        torch_load(model_path, model)
        self.model = model.eval().to(self.device)
        self.inference_args = Namespace(
            **{
                "threshold": 0.5,
                "minlenratio": 0.0,
                "maxlenratio": 10.0,
                "use_attention_constraint": True,
                "backward_window": 1,
                "forward_window": 3,
            })

        print("Loading Vocoder...")
        with open(vocoder_conf) as f:
            self.config = yaml.load(f, Loader=yaml.Loader)
        vocoder_class = self.config.get("generator_type",
                                        "ParallelWaveGANGenerator")
        vocoder = getattr(parallel_wavegan.models,
                          vocoder_class)(**self.config["generator_params"])
        vocoder.load_state_dict(
            torch.load(vocoder_path, map_location="cpu")["model"]["generator"])
        vocoder.remove_weight_norm()
        self.vocoder = vocoder.eval().to(self.device)

        print("Loading Text Frontend...")
        with open(dict_path) as f:
            lines = f.readlines()
        lines = [line.replace("\n", "").split(" ") for line in lines]
        self.char_to_id = {c: int(i) for c, i in lines}
        self.g2p = G2p()

        self.pad_fn = torch.nn.ReplicationPad1d(
            self.config["generator_params"].get("aux_context_window", 0))
        self.use_noise_input = vocoder_class == "ParallelWaveGANGenerator"
Example #21
File: unsup_ar.py  Project: j-pong/HYnet
def recog(args):
    set_deterministic_pytorch(args)

    from espnet.asr.asr_utils import get_model_conf, torch_load
    # read training config
    idim, odim, train_args = get_model_conf(args.model, args.model_conf)

    # load trained model parameters
    model_class = dynamic_import(train_args.model_module)
    model = model_class(idim, odim, train_args)
    torch_load(args.model, model)
    # model, train_args = load_trained_model(args.model)
    model.recog_args = args

    # gpu
    if args.ngpu == 1:
        gpu_id = list(range(args.ngpu))
        logging.info("gpu id: " + str(gpu_id))
        model.cuda()

    # read json data
    with open(args.recog_json, "rb") as f:
        js = json.load(f)["utts"]

    load_inputs_and_targets = LoadInputsAndTargets(
        mode="asr",
        load_output=False,
        sort_in_input_length=False,
        preprocess_conf=None,
        preprocess_args={"train": False},
    )

    ark_file = open(args.result_ark, 'wb')
    if args.batchsize == 0:
        with torch.no_grad():
            for idx, name in enumerate(js.keys(), 1):
                logging.info("(%d/%d) decoding " + name, idx, len(js.keys()))
                batch = [(name, js[name])]

                feat = load_inputs_and_targets(batch)
                feat = feat[0][0]

                hyps = model.recognize(feat)
                hyps = hyps.squeeze(0)
                hyps = hyps.data.numpy()

                write_mat(ark_file, hyps, key=name)
Example #22
def perform_tts(input_text):
    idim, odim, train_args = get_model_conf(model_path)
    model_class = dynamic_import(train_args.model_module)
    model = model_class(idim, odim, train_args)
    torch_load(model_path, model)
    model = model.eval().to(device)
    inference_args = Namespace(
        **{
            "threshold": 0.5,
            "minlenratio": 0.0,
            "maxlenratio": 10.0,
            # Only for Tacotron 2
            "use_attention_constraint": True,
            "backward_window": 1,
            "forward_window": 3,
            # Only for fastspeech (lower than 1.0 is faster speech, higher than 1.0 is slower speech)
            "fastspeech_alpha": 1.0,
        })

    # define neural vocoder
    fs = 22050
    vocoder = load_model(vocoder_path)
    vocoder.remove_weight_norm()
    vocoder = vocoder.eval().to(device)

    # define text frontend

    with open(dict_path) as f:
        lines = f.readlines()
    lines = [line.replace("\n", "").split(" ") for line in lines]
    char_to_id = {c: int(i) for c, i in lines}
    g2p = G2p()

    print('input : ', input_text)
    with torch.no_grad():
        start = time.time()
        x = frontend(input_text, g2p, char_to_id, idim)
        c, _, _ = model.inference(x, inference_args)
        y = vocoder.inference(c)
    rtf = (time.time() - start) / (len(y) / fs)
    print(f"RTF = {rtf:5f}")
    print(y)
    write("static/test.wav", fs, y.view(-1).cpu().numpy())
Example #23
    def _load_teacher_model(self, model_path):
        # get teacher model config
        idim, odim, args = get_model_conf(model_path)

        # assert dimensions match between teacher and student
        assert idim == self.idim
        assert odim == self.odim
        assert args.reduction_factor == self.reduction_factor

        # load teacher model
        from espnet.utils.dynamic_import import dynamic_import
        model_class = dynamic_import(args.model_module)
        model = model_class(idim, odim, args)
        torch_load(model_path, model)

        # freeze teacher model parameters
        for p in model.parameters():
            p.requires_grad = False

        return model
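A hypothetical call site inside a distillation-style student model; the attribute and argument names are placeholders, not from the source.

# Inside the student model's __init__:
self.teacher = self._load_teacher_model(args.teacher_model_path)
assert all(not p.requires_grad for p in self.teacher.parameters())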
Example #24
def load_trained_model(model_path):
    """Load the trained model for recognition.

    Args:
        model_path (str): Path to model.***.best

    """
    idim, odim, train_args = get_model_conf(
        model_path, os.path.join(os.path.dirname(model_path), 'model.json'))

    logging.info('reading model parameters from ' + model_path)

    if hasattr(train_args, "model_module"):
        model_module = train_args.model_module
    else:
        model_module = "espnet.nets.pytorch_backend.e2e_asr:E2E"
    model_class = dynamic_import(model_module)
    model = model_class(idim, odim, train_args)
    torch_load(model_path, model)

    return model, train_args
Example #25
def get_trained_model_state_dict(model_path, new_is_transducer):
    """Extract the trained model state dict for pre-initialization.

    Args:
        model_path (str): Path to trained model.
        new_is_transducer (bool): Whether the new model is Transducer-based.

    Return:
        (Dict): Trained model state dict.

    """
    logging.info(f"Reading model parameters from {model_path}")

    conf_path = os.path.join(os.path.dirname(model_path), "model.json")

    if "rnnlm" in model_path:
        return get_lm_state_dict(torch.load(model_path))

    idim, odim, args = get_model_conf(model_path, conf_path)

    if hasattr(args, "model_module"):
        model_module = args.model_module
    else:
        model_module = "espnet.nets.pytorch_backend.e2e_asr:E2E"

    model_class = dynamic_import(model_module)
    model = model_class(idim, odim, args)
    torch_load(model_path, model)

    assert (isinstance(model, MTInterface) or isinstance(model, ASRInterface)
            or isinstance(model, TTSInterface))

    if new_is_transducer and "transducer" not in args.model_module:
        return create_transducer_compatible_state_dict(
            model.state_dict(),
            args.etype,
            args.eunits,
        )

    return model.state_dict()
Example #26
def load_trained_model(model_path, use_ema=True, training=True):
    """Load the trained model for recognition.

    Args:
        model_path (str): Path to model.***.best
        use_ema (bool): Use EMA parameters when available.
        training (bool): Training mode specification for transducer model.

    """
    idim, odim, train_args = get_model_conf(
        model_path, os.path.join(os.path.dirname(model_path), "model.json"))

    logging.warning("reading model parameters from " + model_path)

    if hasattr(train_args, "model_module"):
        model_module = train_args.model_module
    else:
        model_module = "espnet.nets.pytorch_backend.e2e_asr:E2E"
    # CTC loss is not needed at decoding time; default to "builtin" to prevent import errors
    if hasattr(train_args, "ctc_type"):
        train_args.ctc_type = "builtin"

    model_class = dynamic_import(model_module)

    if "transducer" in model_module:
        model = model_class(idim, odim, train_args, training=training)
        loading_fn = custom_torch_load
    else:
        model = model_class(idim, odim, train_args)
        loading_fn = torch_load

    if hasattr(train_args, "ema_decay") and train_args.ema_decay > 0:
        model = EMA(model, train_args.ema_decay)
    loading_fn(model_path, model)

    if use_ema and hasattr(model, "shadow"):
        logging.warning("use EMA parameters")
        model = model.shadow

    return model, train_args
Example #27
def dynamic_import_optimizer(name: str, backend: str) -> type:
    """Import optimizer class dynamically.

    Args:
        name (str): alias name or dynamic import syntax `module:class`
        backend (str): backend name e.g., chainer or pytorch

    Returns:
        OptimizerAdaptorInterface or FunctionalOptimizerAdaptor

    """
    if name in OPTIMIZER_PARSER_DICT:
        if backend == "pytorch":
            from espnet.optimizer.pytorch import OPTIMIZER_BUILDER_DICT
        elif backend == "chainer":
            from espnet.optimizer.chainer import OPTIMIZER_BUILDER_DICT
        else:
            raise NotImplementedError(f"unsupported backend: {backend}")
        return FunctionalOptimizerAdaptor(OPTIMIZER_BUILDER_DICT[name],
                                          OPTIMIZER_PARSER_DICT[name])
    adaptor_class = dynamic_import(name)
    assert issubclass(adaptor_class, OptimizerAdaptorInterface)
    return adaptor_class
Example #28
def load_trained_model(model_path, training=True):
    """Load the trained model for recognition.

    Args:
        model_path (str): Path to model.***.best
        training (bool): Training mode specification for transducer model.

    Returns:
        model (torch.nn.Module): Trained model.
        train_args (Namespace): Trained model arguments.

    """
    idim, odim, train_args = get_model_conf(
        model_path, os.path.join(os.path.dirname(model_path), "model.json"))

    logging.info(f"Reading model parameters from {model_path}")

    if hasattr(train_args, "model_module"):
        model_module = train_args.model_module
    else:
        model_module = "espnet.nets.pytorch_backend.e2e_asr:E2E"

    # CTC loss is not needed at decoding time; default to "builtin" to prevent import errors
    if hasattr(train_args, "ctc_type"):
        train_args.ctc_type = "builtin"

    model_class = dynamic_import(model_module)

    if "transducer" in model_module:
        model = model_class(idim, odim, train_args, training=training)
        custom_torch_load(model_path, model, training=training)
    else:
        model = model_class(idim, odim, train_args)
        torch_load(model_path, model)

    return model, train_args
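A minimal decoding-side sketch (hypothetical path) using the loader above.

model, train_args = load_trained_model(
    "exp/train_asr/results/model.acc.best", training=False)
model.eval()  # disable dropout etc. before recognition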
Example #29
def train(args):
    """Train with the given args.

    Args:
        args (namespace): The program arguments.

    """
    set_deterministic_pytorch(args)

    # check cuda availability
    if not torch.cuda.is_available():
        logging.warning("cuda is not available")

    # get input and output dimension info
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]
    utts = list(valid_json.keys())
    idim = int(valid_json[utts[0]]["output"][1]["shape"][1])
    odim = int(valid_json[utts[0]]["output"][0]["shape"][1])
    logging.info("#input dims : " + str(idim))
    logging.info("#output dims: " + str(odim))

    # specify model architecture
    model_class = dynamic_import(args.model_module)
    model = model_class(idim, odim, args)
    assert isinstance(model, MTInterface)

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + "/model.json"
    with open(model_conf, "wb") as f:
        logging.info("writing a model config file to " + model_conf)
        f.write(
            json.dumps((idim, odim, vars(args)),
                       indent=4,
                       ensure_ascii=False,
                       sort_keys=True).encode("utf_8"))
    for key in sorted(vars(args).keys()):
        logging.info("ARGS: " + key + ": " + str(vars(args)[key]))

    reporter = model.reporter

    # check the use of multi-gpu
    if args.ngpu > 1:
        if args.batch_size != 0:
            logging.warning(
                "batch size is automatically increased (%d -> %d)" %
                (args.batch_size, args.batch_size * args.ngpu))
            args.batch_size *= args.ngpu

    # set torch device
    device = torch.device("cuda" if args.ngpu > 0 else "cpu")
    if args.train_dtype in ("float16", "float32", "float64"):
        dtype = getattr(torch, args.train_dtype)
    else:
        dtype = torch.float32
    model = model.to(device=device, dtype=dtype)

    logging.warning(
        "num. model params: {:,} (num. trained: {:,} ({:.1f}%))".format(
            sum(p.numel() for p in model.parameters()),
            sum(p.numel() for p in model.parameters() if p.requires_grad),
            sum(p.numel() for p in model.parameters() if p.requires_grad) *
            100.0 / sum(p.numel() for p in model.parameters()),
        ))

    # Setup an optimizer
    if args.opt == "adadelta":
        optimizer = torch.optim.Adadelta(model.parameters(),
                                         rho=0.95,
                                         eps=args.eps,
                                         weight_decay=args.weight_decay)
    elif args.opt == "adam":
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     weight_decay=args.weight_decay)
    elif args.opt == "noam":
        from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt

        optimizer = get_std_opt(
            model.parameters(),
            args.adim,
            args.transformer_warmup_steps,
            args.transformer_lr,
        )
    else:
        raise NotImplementedError("unknown optimizer: " + args.opt)

    # setup apex.amp
    if args.train_dtype in ("O0", "O1", "O2", "O3"):
        try:
            from apex import amp
        except ImportError as e:
            logging.error(
                f"You need to install apex for --train-dtype {args.train_dtype}. "
                "See https://github.com/NVIDIA/apex#linux")
            raise e
        if args.opt == "noam":
            model, optimizer.optimizer = amp.initialize(
                model, optimizer.optimizer, opt_level=args.train_dtype)
        else:
            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level=args.train_dtype)
        use_apex = True
    else:
        use_apex = False

    # FIXME: TOO DIRTY HACK
    setattr(optimizer, "target", reporter)
    setattr(optimizer, "serialize", lambda s: reporter.serialize(s))

    # Setup a converter
    converter = CustomConverter()

    # read json data
    with open(args.train_json, "rb") as f:
        train_json = json.load(f)["utts"]
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]

    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    # make minibatch list (variable length)
    train = make_batchset(
        train_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        shortest_first=use_sortagrad,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        mt=True,
        iaxis=1,
        oaxis=0,
    )
    valid = make_batchset(
        valid_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        mt=True,
        iaxis=1,
        oaxis=0,
    )

    load_tr = LoadInputsAndTargets(mode="mt", load_output=True)
    load_cv = LoadInputsAndTargets(mode="mt", load_output=True)
    # hack to make the batchsize argument 1: the actual batch size is
    # encoded inside each list element. The default collate function
    # converts numpy arrays to pytorch tensors, so we use an identity-like
    # collate function that returns the single element as-is.
    train_iter = ChainerDataLoader(
        dataset=TransformDataset(train,
                                 lambda data: converter([load_tr(data)])),
        batch_size=1,
        num_workers=args.n_iter_processes,
        shuffle=not use_sortagrad,
        collate_fn=lambda x: x[0],
    )
    valid_iter = ChainerDataLoader(
        dataset=TransformDataset(valid,
                                 lambda data: converter([load_cv(data)])),
        batch_size=1,
        shuffle=False,
        collate_fn=lambda x: x[0],
        num_workers=args.n_iter_processes,
    )

    # Set up a trainer
    updater = CustomUpdater(
        model,
        args.grad_clip,
        {"main": train_iter},
        optimizer,
        device,
        args.ngpu,
        False,
        args.accum_grad,
        use_apex=use_apex,
    )
    trainer = training.Trainer(updater, (args.epochs, "epoch"),
                               out=args.outdir)

    if use_sortagrad:
        trainer.extend(
            ShufflingEnabler([train_iter]),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs,
                     "epoch"),
        )

    # Resume from a snapshot
    if args.resume:
        logging.info("resumed from %s" % args.resume)
        torch_resume(args.resume, trainer)

    # Evaluate the model with the test dataset for each epoch
    if args.save_interval_iters > 0:
        trainer.extend(
            CustomEvaluator(model, {"main": valid_iter}, reporter, device,
                            args.ngpu),
            trigger=(args.save_interval_iters, "iteration"),
        )
    else:
        trainer.extend(
            CustomEvaluator(model, {"main": valid_iter}, reporter, device,
                            args.ngpu))

    # Save attention weight each epoch
    if args.num_save_attention > 0:
        # NOTE: sort it by output lengths
        data = sorted(
            list(valid_json.items())[:args.num_save_attention],
            key=lambda x: int(x[1]["output"][0]["shape"][0]),
            reverse=True,
        )
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
            plot_class = model.module.attention_plot_class
        else:
            att_vis_fn = model.calculate_all_attentions
            plot_class = model.attention_plot_class
        att_reporter = plot_class(
            att_vis_fn,
            data,
            args.outdir + "/att_ws",
            converter=converter,
            transform=load_cv,
            device=device,
            ikey="output",
            iaxis=1,
        )
        trainer.extend(att_reporter, trigger=(1, "epoch"))
    else:
        att_reporter = None

    # Make a plot for training and validation values
    trainer.extend(
        extensions.PlotReport(["main/loss", "validation/main/loss"],
                              "epoch",
                              file_name="loss.png"))
    trainer.extend(
        extensions.PlotReport(["main/acc", "validation/main/acc"],
                              "epoch",
                              file_name="acc.png"))
    trainer.extend(
        extensions.PlotReport(["main/ppl", "validation/main/ppl"],
                              "epoch",
                              file_name="ppl.png"))
    trainer.extend(
        extensions.PlotReport(["main/bleu", "validation/main/bleu"],
                              "epoch",
                              file_name="bleu.png"))

    # Save best models
    trainer.extend(
        snapshot_object(model, "model.loss.best"),
        trigger=training.triggers.MinValueTrigger("validation/main/loss"),
    )
    trainer.extend(
        snapshot_object(model, "model.acc.best"),
        trigger=training.triggers.MaxValueTrigger("validation/main/acc"),
    )

    # save snapshot which contains model and optimizer states
    if args.save_interval_iters > 0:
        trainer.extend(
            torch_snapshot(filename="snapshot.iter.{.updater.iteration}"),
            trigger=(args.save_interval_iters, "iteration"),
        )
    else:
        trainer.extend(torch_snapshot(), trigger=(1, "epoch"))

    # epsilon decay in the optimizer
    if args.opt == "adadelta":
        if args.criterion == "acc":
            trainer.extend(
                restore_snapshot(model,
                                 args.outdir + "/model.acc.best",
                                 load_fn=torch_load),
                trigger=CompareValueTrigger(
                    "validation/main/acc",
                    lambda best_value, current_value: best_value >
                    current_value,
                ),
            )
            trainer.extend(
                adadelta_eps_decay(args.eps_decay),
                trigger=CompareValueTrigger(
                    "validation/main/acc",
                    lambda best_value, current_value: best_value >
                    current_value,
                ),
            )
        elif args.criterion == "loss":
            trainer.extend(
                restore_snapshot(model,
                                 args.outdir + "/model.loss.best",
                                 load_fn=torch_load),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value <
                    current_value,
                ),
            )
            trainer.extend(
                adadelta_eps_decay(args.eps_decay),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value <
                    current_value,
                ),
            )
    elif args.opt == "adam":
        if args.criterion == "acc":
            trainer.extend(
                restore_snapshot(model,
                                 args.outdir + "/model.acc.best",
                                 load_fn=torch_load),
                trigger=CompareValueTrigger(
                    "validation/main/acc",
                    lambda best_value, current_value: best_value >
                    current_value,
                ),
            )
            trainer.extend(
                adam_lr_decay(args.lr_decay),
                trigger=CompareValueTrigger(
                    "validation/main/acc",
                    lambda best_value, current_value: best_value >
                    current_value,
                ),
            )
        elif args.criterion == "loss":
            trainer.extend(
                restore_snapshot(model,
                                 args.outdir + "/model.loss.best",
                                 load_fn=torch_load),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value <
                    current_value,
                ),
            )
            trainer.extend(
                adam_lr_decay(args.lr_decay),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value <
                    current_value,
                ),
            )

    # Write a log of evaluation statistics for each epoch
    trainer.extend(
        extensions.LogReport(trigger=(args.report_interval_iters,
                                      "iteration")))
    report_keys = [
        "epoch",
        "iteration",
        "main/loss",
        "validation/main/loss",
        "main/acc",
        "validation/main/acc",
        "main/ppl",
        "validation/main/ppl",
        "elapsed_time",
    ]
    if args.opt == "adadelta":
        trainer.extend(
            extensions.observe_value(
                "eps",
                lambda trainer: trainer.updater.get_optimizer("main").
                param_groups[0]["eps"],
            ),
            trigger=(args.report_interval_iters, "iteration"),
        )
        report_keys.append("eps")
    elif args.opt in ["adam", "noam"]:
        trainer.extend(
            extensions.observe_value(
                "lr",
                lambda trainer: trainer.updater.get_optimizer("main").
                param_groups[0]["lr"],
            ),
            trigger=(args.report_interval_iters, "iteration"),
        )
        report_keys.append("lr")
    if args.report_bleu:
        report_keys.append("main/bleu")
        report_keys.append("validation/main/bleu")
    trainer.extend(
        extensions.PrintReport(report_keys),
        trigger=(args.report_interval_iters, "iteration"),
    )

    trainer.extend(
        extensions.ProgressBar(update_interval=args.report_interval_iters))
    set_early_stop(trainer, args)

    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        from torch.utils.tensorboard import SummaryWriter

        trainer.extend(
            TensorboardLogger(SummaryWriter(args.tensorboard_dir),
                              att_reporter),
            trigger=(args.report_interval_iters, "iteration"),
        )
    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
Example #30
def main(cmd_args):
    """Run the main training function."""
    parser = get_parser()
    args, _ = parser.parse_known_args(cmd_args)
    if args.backend == "chainer" and args.train_dtype != "float32":
        raise NotImplementedError(
            f"chainer backend does not support --train-dtype {args.train_dtype}."
            "Use --dtype float32."
        )
    if args.ngpu == 0 and args.train_dtype in ("O0", "O1", "O2", "O3", "float16"):
        raise ValueError(
            f"--train-dtype {args.train_dtype} does not support the CPU backend."
        )

    from espnet.utils.dynamic_import import dynamic_import

    if args.model_module is None:
        model_module = "espnet.nets." + args.backend + "_backend.e2e_asr:E2E"
    else:
        model_module = args.model_module
    model_class = dynamic_import(model_module)
    model_class.add_arguments(parser)

    args = parser.parse_args(cmd_args)
    args.model_module = model_module
    if "chainer_backend" in args.model_module:
        args.backend = "chainer"
    if "pytorch_backend" in args.model_module:
        args.backend = "pytorch"

    # add version info in args
    args.version = __version__

    # logging info
    if args.verbose > 0:
        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    else:
        logging.basicConfig(
            level=logging.WARN,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
        logging.warning("Skip DEBUG/INFO messages")

    # If --ngpu is not given,
    #   1. if CUDA_VISIBLE_DEVICES is set, all visible devices
    #   2. if nvidia-smi exists, use all devices
    #   3. else ngpu=0
    if args.ngpu is None:
        cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
        if cvd is not None:
            ngpu = len(cvd.split(","))
        else:
            logging.warning("CUDA_VISIBLE_DEVICES is not set.")
            try:
                p = subprocess.run(
                    ["nvidia-smi", "-L"], stdout=subprocess.PIPE, stderr=subprocess.PIPE
                )
            except (subprocess.CalledProcessError, FileNotFoundError):
                ngpu = 0
            else:
                # nvidia-smi -L prints one line per GPU on stdout
                ngpu = len(p.stdout.decode().split("\n")) - 1
    else:
        if is_torch_1_2_plus and args.ngpu != 1:
            logging.debug(
                "There are some bugs with multi-GPU processing in PyTorch 1.2+"
                + " (see https://github.com/pytorch/pytorch/issues/21108)"
            )
        ngpu = args.ngpu
    logging.info(f"ngpu: {ngpu}")

    # display PYTHONPATH
    logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)"))

    # set random seed
    logging.info("random seed = %d" % args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    # load dictionary for debug log
    if args.dict is not None:
        with open(args.dict, "rb") as f:
            dictionary = f.readlines()
        char_list = [entry.decode("utf-8").split(" ")[0] for entry in dictionary]
        char_list.insert(0, "<blank>")
        char_list.append("<eos>")
        # for non-autoregressive maskctc model
        if "maskctc" in args.model_module:
            char_list.append("<mask>")
        args.char_list = char_list
    else:
        args.char_list = None

    # train
    logging.info("backend = " + args.backend)

    if args.num_spkrs == 1:
        if args.backend == "chainer":
            from espnet.asr.chainer_backend.asr import train

            train(args)
        elif args.backend == "pytorch":
            from espnet.asr.pytorch_backend.asr import train

            train(args)
        else:
            raise ValueError("Only chainer and pytorch are supported.")
    else:
        # FIXME(kamo): Support --model-module
        if args.backend == "pytorch":
            from espnet.asr.pytorch_backend.asr_mix import train

            train(args)
        else:
            raise ValueError("Only pytorch is supported.")