예제 #1
0
def main():
    """Parse the recognition command-line options and run decoding.

    The options are read from ``sys.argv``.  After logging and the random
    seeds are configured, the parsed ``argparse.Namespace`` is handed to the
    ``recog`` function of the selected backend module (``asr_chainer`` or
    ``asr_pytorch``).

    Raises:
        ValueError: If ``--backend`` is neither ``chainer`` nor ``pytorch``.
            argparse already rejects other values via ``choices``, so this is
            only a defensive guard.
    """
    parser = argparse.ArgumentParser()
    # general configuration
    parser.add_argument('--gpu', '-g', default='-1', type=str,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--backend', default='chainer', type=str,
                        choices=['chainer', 'pytorch'],
                        help='Backend library')
    parser.add_argument('--debugmode', default=1, type=int, help='Debugmode')
    parser.add_argument('--seed', default=1, type=int, help='Random seed')
    parser.add_argument('--verbose', '-V', default=1, type=int,
                        help='Verbose option')
    # task related
    parser.add_argument('--recog-feat', type=str, required=True,
                        help='Filename of recognition feature data (Kaldi scp)')
    parser.add_argument('--recog-label', type=str, required=True,
                        help='Filename of recognition label data (json)')
    parser.add_argument('--result-label', type=str, required=True,
                        help='Filename of result label data (json)')
    # model (parameter) related
    parser.add_argument('--model', type=str, required=True,
                        help='Model file parameters to read')
    parser.add_argument('--model-conf', type=str, required=True,
                        help='Model config file')
    # search related
    parser.add_argument('--nbest', type=int, default=1,
                        help='Output N-best hypotheses')
    parser.add_argument('--beam-size', type=int, default=1, help='Beam size')
    parser.add_argument('--penalty', default=0.0, type=float,
                        help='Insertion penalty')
    # Adjacent literals end with a space so the sentences do not run together
    # in the rendered --help output.
    parser.add_argument(
        '--maxlenratio', default=0.0, type=float,
        help='Input length ratio to obtain max output length. '
             'If maxlenratio=0.0 (default), it uses a end-detect function '
             'to automatically find maximum hypothesis lengths')
    parser.add_argument('--minlenratio', default=0.0, type=float,
                        help='Input length ratio to obtain min output length')
    parser.add_argument('--ctc-weight', default=0.0, type=float,
                        help='CTC weight in joint decoding')
    # rnnlm related
    parser.add_argument('--rnnlm', type=str, default=None,
                        help='RNNLM model file to read')
    parser.add_argument('--lm-weight', default=0.1, type=float,
                        help='RNNLM weight.')
    args = parser.parse_args()

    # logging info: 1 -> INFO, 2 -> DEBUG, anything else -> WARN
    verbose_levels = {1: logging.INFO, 2: logging.DEBUG}
    logging.basicConfig(
        level=verbose_levels.get(args.verbose, logging.WARN),
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
    if args.verbose not in verbose_levels:
        logging.warning("Skip DEBUG/INFO messages")

    # display PYTHONPATH; it may legitimately be unset, so never index
    # os.environ directly here (that would abort decoding with a KeyError).
    logging.info('python path = ' + os.environ.get('PYTHONPATH', '(None)'))

    # seed setting
    random.seed(args.seed)
    np.random.seed(args.seed)
    logging.info('set random seed = %d' % args.seed)

    # recog: the backend module is imported lazily so that only the selected
    # framework has to be installed.
    logging.info('backend = ' + args.backend)
    if args.backend == "chainer":
        from asr_chainer import recog
        recog(args)
    elif args.backend == "pytorch":
        from asr_pytorch import recog
        recog(args)
    else:
        raise ValueError("chainer and pytorch are only supported.")
예제 #2
0
파일: asr_recog.py 프로젝트: yubouf/espnet
def main():
    """Parse the recognition command-line options and run decoding.

    The options are read from ``sys.argv``.  After logging is configured,
    the GPU count is checked against ``CUDA_VISIBLE_DEVICES`` and the random
    seeds are set.  The parsed ``argparse.Namespace`` is then handed to the
    ``recog`` function of the selected backend module (``asr_chainer`` or
    ``asr_pytorch``).

    Raises:
        SystemExit: With status 1 when ``--ngpu`` is positive and does not
            equal the number of comma-separated entries in
            ``CUDA_VISIBLE_DEVICES``.
        ValueError: If ``--backend`` is neither ``chainer`` nor ``pytorch``.
            argparse already rejects other values via ``choices``, so this is
            only a defensive guard.
    """
    parser = argparse.ArgumentParser()
    # general configuration
    parser.add_argument('--ngpu', default=0, type=int, help='Number of GPUs')
    parser.add_argument('--backend', default='chainer', type=str,
                        choices=['chainer', 'pytorch'],
                        help='Backend library')
    parser.add_argument('--debugmode', default=1, type=int, help='Debugmode')
    parser.add_argument('--seed', default=1, type=int, help='Random seed')
    parser.add_argument('--verbose', '-V', default=1, type=int,
                        help='Verbose option')
    # task related
    parser.add_argument('--recog-json', type=str,
                        help='Filename of recognition data (json)')
    parser.add_argument('--result-label', type=str, required=True,
                        help='Filename of result label data (json)')
    # model (parameter) related
    parser.add_argument('--model', type=str, required=True,
                        help='Model file parameters to read')
    parser.add_argument('--model-conf', type=str, default=None,
                        help='Model config file')
    # search related
    parser.add_argument('--nbest', type=int, default=1,
                        help='Output N-best hypotheses')
    parser.add_argument('--beam-size', type=int, default=1, help='Beam size')
    parser.add_argument('--penalty', default=0.0, type=float,
                        help='Insertion penalty')
    parser.add_argument(
        '--maxlenratio', default=0.0, type=float,
        help='Input length ratio to obtain max output length. '
             'If maxlenratio=0.0 (default), it uses a end-detect function '
             'to automatically find maximum hypothesis lengths')
    parser.add_argument('--minlenratio', default=0.0, type=float,
                        help='Input length ratio to obtain min output length')
    parser.add_argument('--ctc-weight', default=0.0, type=float,
                        help='CTC weight in joint decoding')
    # rnnlm related
    parser.add_argument('--rnnlm', type=str, default=None,
                        help='RNNLM model file to read')
    parser.add_argument('--rnnlm-conf', type=str, default=None,
                        help='RNNLM model config file to read')
    parser.add_argument('--word-rnnlm', type=str, default=None,
                        help='Word RNNLM model file to read')
    parser.add_argument('--word-dict', type=str, default=None,
                        help='Word list to read')
    parser.add_argument('--lm-weight', default=0.1, type=float,
                        help='RNNLM weight.')
    args = parser.parse_args()

    # logging info: 1 -> INFO, 2 -> DEBUG, anything else -> WARN
    verbose_levels = {1: logging.INFO, 2: logging.DEBUG}
    logging.basicConfig(
        level=verbose_levels.get(args.verbose, logging.WARN),
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s")
    if args.verbose not in verbose_levels:
        logging.warning("Skip DEBUG/INFO messages")

    # check CUDA_VISIBLE_DEVICES
    if args.ngpu > 0:
        cvd = os.environ.get("CUDA_VISIBLE_DEVICES")
        if cvd is None:
            # logging.warn is a deprecated alias (removed in Python 3.13);
            # logging.warning is the supported spelling.
            logging.warning("CUDA_VISIBLE_DEVICES is not set.")
        elif args.ngpu != len(cvd.split(",")):
            logging.error("#gpus is not matched with CUDA_VISIBLE_DEVICES.")
            sys.exit(1)

    # display PYTHONPATH; it may legitimately be unset, so never index
    # os.environ directly here (that would abort decoding with a KeyError).
    logging.info('python path = ' + os.environ.get('PYTHONPATH', '(None)'))

    # seed setting
    random.seed(args.seed)
    np.random.seed(args.seed)
    logging.info('set random seed = %d' % args.seed)

    # recog: the backend module is imported lazily so that only the selected
    # framework has to be installed.
    logging.info('backend = ' + args.backend)
    if args.backend == "chainer":
        from asr_chainer import recog
        recog(args)
    elif args.backend == "pytorch":
        from asr_pytorch import recog
        recog(args)
    else:
        raise ValueError("chainer and pytorch are only supported.")