コード例 #1
0
def main(args):
    parser = get_parser()
    args = parser.parse_args(args)

    # logging info
    if args.verbose == 1:
        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    elif args.verbose == 2:
        logging.basicConfig(
            level=logging.DEBUG,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
    else:
        logging.basicConfig(
            level=logging.WARN,
            format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
        )
        logging.warning("Skip DEBUG/INFO messages")

    # check CUDA_VISIBLE_DEVICES
    if args.ngpu > 0:
        cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
        if cvd is None:
            logging.warning("CUDA_VISIBLE_DEVICES is not set.")
        elif args.ngpu != len(cvd.split(",")):
            logging.error("#gpus is not matched with CUDA_VISIBLE_DEVICES.")
            sys.exit(1)

        # TODO(kamo): support of multiple GPUs
        if args.ngpu > 1:
            logging.error("The program only supports ngpu=1.")
            sys.exit(1)

    # display PYTHONPATH
    logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)"))

    # seed setting
    random.seed(args.seed)
    np.random.seed(args.seed)
    logging.info("set random seed = %d" % args.seed)

    # recog
    logging.info("backend = " + args.backend)
    if args.backend == "pytorch":
        if args.num_spkrs == 1:
            from espnet.asr.pytorch_backend.asr import enhance
            enhance(args)
        else:
            from espnet.asr.pytorch_backend.asr_mix import enhance
            enhance(args)
    else:
        raise ValueError("Only pytorch is supported.")
コード例 #2
0
ファイル: asr_enhance.py プロジェクト: xqq2018rebuild/espnet
def main(args):
    parser = configargparse.ArgumentParser(
        config_file_parser_class=configargparse.YAMLConfigFileParser,
        formatter_class=configargparse.ArgumentDefaultsHelpFormatter)
    # general configuration
    parser.add('--config', is_config_file=True, help='config file path')
    parser.add(
        '--config2',
        is_config_file=True,
        help=
        'second config file path that overwrites the settings in `--config`.')
    parser.add(
        '--config3',
        is_config_file=True,
        help=
        'third config file path that overwrites the settings in `--config` and `--config2`.'
    )

    parser.add_argument('--ngpu', default=0, type=int, help='Number of GPUs')
    parser.add_argument('--backend',
                        default='chainer',
                        type=str,
                        choices=['chainer', 'pytorch'],
                        help='Backend library')
    parser.add_argument('--debugmode', default=1, type=int, help='Debugmode')
    parser.add_argument('--seed', default=1, type=int, help='Random seed')
    parser.add_argument('--verbose',
                        '-V',
                        default=1,
                        type=int,
                        help='Verbose option')
    parser.add_argument(
        '--batchsize',
        default=1,
        type=int,
        help='Batch size for beam search (0: means no batch processing)')
    parser.add_argument('--preprocess-conf',
                        type=str,
                        default=None,
                        help='The configuration file for the pre-processing')
    # task related
    parser.add_argument('--recog-json',
                        type=str,
                        help='Filename of recognition data (json)')
    # model (parameter) related
    parser.add_argument('--model',
                        type=str,
                        required=True,
                        help='Model file parameters to read')
    parser.add_argument('--model-conf',
                        type=str,
                        default=None,
                        help='Model config file')

    # Outputs configuration
    parser.add_argument('--enh-wspecifier',
                        type=str,
                        default=None,
                        help='Specify the output way for enhanced speech.'
                        'e.g. ark,scp:outdir,wav.scp')
    parser.add_argument('--enh-filetype',
                        type=str,
                        default='sound',
                        choices=['mat', 'hdf5', 'sound.hdf5', 'sound'],
                        help='Specify the file format for enhanced speech. '
                        '"mat" is the matrix format in kaldi')
    parser.add_argument('--fs',
                        type=int,
                        default=16000,
                        help='The sample frequency')
    parser.add_argument('--keep-length',
                        type=strtobool,
                        default=True,
                        help='Adjust the output length to match '
                        'with the input for enhanced speech')
    parser.add_argument('--image-dir',
                        type=str,
                        default=None,
                        help='The directory saving the images.')
    parser.add_argument('--num-images',
                        type=int,
                        default=20,
                        help='The number of images files to be saved. '
                        'If negative, all samples are to be saved.')

    # IStft
    parser.add_argument('--apply-istft',
                        type=strtobool,
                        default=True,
                        help='Apply istft to the output from the network')
    parser.add_argument('--istft-win-length',
                        type=int,
                        default=512,
                        help='The window length for istft. '
                        'This option is ignored '
                        'if stft is found in the preprocess-conf')
    parser.add_argument('--istft-n-shift',
                        type=str,
                        default=256,
                        help='The window type for istft. '
                        'This option is ignored '
                        'if stft is found in the preprocess-conf')
    parser.add_argument('--istft-window',
                        type=str,
                        default='hann',
                        help='The window type for istft. '
                        'This option is ignored '
                        'if stft is found in the preprocess-conf')

    args = parser.parse_args(args)

    # logging info
    if args.verbose == 1:
        logging.basicConfig(
            level=logging.INFO,
            format=
            "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
    elif args.verbose == 2:
        logging.basicConfig(
            level=logging.DEBUG,
            format=
            "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
    else:
        logging.basicConfig(
            level=logging.WARN,
            format=
            "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
        logging.warning("Skip DEBUG/INFO messages")

    # check CUDA_VISIBLE_DEVICES
    if args.ngpu > 0:
        cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
        if cvd is None:
            logging.warning("CUDA_VISIBLE_DEVICES is not set.")
        elif args.ngpu != len(cvd.split(",")):
            logging.error("#gpus is not matched with CUDA_VISIBLE_DEVICES.")
            sys.exit(1)

        # TODO(kamo): support of multiple GPUs
        if args.ngpu > 1:
            logging.error("The program only supports ngpu=1.")
            sys.exit(1)

    # display PYTHONPATH
    logging.info('python path = ' + os.environ.get('PYTHONPATH', '(None)'))

    # seed setting
    random.seed(args.seed)
    np.random.seed(args.seed)
    logging.info('set random seed = %d' % args.seed)

    # recog
    logging.info('backend = ' + args.backend)
    if args.backend == "pytorch":
        enhance(args)
    else:
        raise ValueError("Only pytorch is supported.")