Example #1
def main():
    args = parser.parse_args()

    pipeline_config_path = find_config(args.config_path)
    https = args.https
    ssl_key = args.key
    ssl_cert = args.cert

    if args.download or args.mode == 'download':
        deep_download(pipeline_config_path)

    multi_instance = args.multi_instance
    stateful = args.stateful

    start_epoch_num = args.start_epoch_num

    if args.mode == 'train':
        train_evaluate_model_from_config(pipeline_config_path,
                                         recursive=args.recursive,
                                         start_epoch_num=start_epoch_num)
    elif args.mode == 'evaluate':
        train_evaluate_model_from_config(pipeline_config_path, to_train=False, to_validate=False,
                                         start_epoch_num=start_epoch_num)
    elif args.mode == 'interact':
        interact_model(pipeline_config_path)
    elif args.mode == 'interactbot':
        token = args.token
        interact_model_by_telegram(pipeline_config_path, token)
    elif args.mode == 'interactmsbot':
        ms_id = args.ms_id
        ms_secret = args.ms_secret
        run_ms_bf_default_agent(model_config=pipeline_config_path,
                                app_id=ms_id,
                                app_secret=ms_secret,
                                multi_instance=multi_instance,
                                stateful=stateful,
                                port=args.port)
    elif args.mode == 'alexa':
        run_alexa_default_agent(model_config=pipeline_config_path,
                                multi_instance=multi_instance,
                                stateful=stateful,
                                port=args.port,
                                https=https,
                                ssl_key=ssl_key,
                                ssl_cert=ssl_cert)
    elif args.mode == 'riseapi':
        alice = args.api_mode == 'alice'
        if alice:
            start_alice_server(pipeline_config_path, https, ssl_key, ssl_cert, port=args.port)
        else:
            start_model_server(pipeline_config_path, https, ssl_key, ssl_cert, port=args.port)
    elif args.mode == 'predict':
        predict_on_stream(pipeline_config_path, args.batch_size, args.file_path)
    elif args.mode == 'install':
        install_from_config(pipeline_config_path)
    elif args.mode == 'crossval':
        if args.folds < 2:
            log.error('Minimum number of Folds is 2')
        else:
            n_folds = args.folds
            calc_cv_score(pipeline_config_path, n_folds=n_folds, is_loo=False)
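
All of these examples reference a module-level `parser` that is defined elsewhere in the file. Below is a minimal, hypothetical sketch of such a parser, reconstructed only from the attributes Example #1 reads (args.mode, args.config_path, args.download, args.token, and so on); the real option names, defaults and help texts may differ, and later variants add further options (e.g. args.gpu_id, args.no_default_skill, args.socket_type) in the same way.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('mode', help='operating mode',
                    choices=['train', 'evaluate', 'interact', 'interactbot', 'interactmsbot',
                             'alexa', 'riseapi', 'predict', 'install', 'crossval', 'download'])
parser.add_argument('config_path', help='path to a pipeline config file or a bare config name')
parser.add_argument('-d', '--download', action='store_true', help='download model files before running')
parser.add_argument('-t', '--token', default=None, help='Telegram bot token')
parser.add_argument('-i', '--ms-id', dest='ms_id', default=None, help='MS Bot Framework app id')
parser.add_argument('-s', '--ms-secret', dest='ms_secret', default=None, help='MS Bot Framework app secret')
parser.add_argument('-b', '--batch-size', dest='batch_size', type=int, default=1)
parser.add_argument('-f', '--file-path', dest='file_path', default=None)
parser.add_argument('--folds', type=int, default=5)
parser.add_argument('--port', type=int, default=None)
parser.add_argument('--https', action='store_true')
parser.add_argument('--key', default=None, help='path to an SSL key')
parser.add_argument('--cert', default=None, help='path to an SSL certificate')
parser.add_argument('--multi-instance', dest='multi_instance', action='store_true')
parser.add_argument('--stateful', action='store_true')
parser.add_argument('--recursive', action='store_true')
parser.add_argument('--start-epoch-num', dest='start_epoch_num', type=int, default=None)
parser.add_argument('--api-mode', dest='api_mode', default=None)
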
Example #2
def main():
    args = parser.parse_args()
    pipeline_config_path = args.config_path
    if not Path(pipeline_config_path).is_file():
        configs = [c for c in Path(__file__).parent.glob(f'configs/**/{pipeline_config_path}.json')
                   if str(c.with_suffix('')).endswith(pipeline_config_path)]  # a simple way to not allow * and ?
        if configs:
            log.info(f"Interpreting '{pipeline_config_path}' as '{configs[0]}'")
            pipeline_config_path = str(configs[0])

    token = args.token or os.getenv('TELEGRAM_TOKEN')

    if args.download or args.mode == 'download':
        deep_download(['-c', pipeline_config_path])

    if args.mode == 'train':
        train_model_from_config(pipeline_config_path)
    elif args.mode == 'interact':
        interact_model(pipeline_config_path)
    elif args.mode == 'interactbot':
        if not token:
            log.error('Token required: specify the -t parameter or set the TELEGRAM_TOKEN env var with the Telegram bot token')
        else:
            interact_model_by_telegram(pipeline_config_path, token)
    elif args.mode == 'riseapi':
        start_model_server(pipeline_config_path)
    elif args.mode == 'predict':
        predict_on_stream(pipeline_config_path, args.batch_size, args.file_path)
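
Example #2 resolves a bare config name to a file under configs/ inline, while the other examples delegate the same lookup to find_config(). A minimal sketch of that helper, assuming it simply wraps the inline logic shown above, could look like this:

import logging
from pathlib import Path

log = logging.getLogger(__name__)

def find_config(pipeline_config_path: str) -> str:
    """Resolve a bare config name to a JSON file under configs/; pass existing paths through."""
    if not Path(pipeline_config_path).is_file():
        configs = [c for c in Path(__file__).parent.glob(f'configs/**/{pipeline_config_path}.json')
                   if str(c.with_suffix('')).endswith(pipeline_config_path)]  # a simple way to not allow * and ?
        if configs:
            log.info(f"Interpreting '{pipeline_config_path}' as '{configs[0]}'")
            pipeline_config_path = str(configs[0])
    return pipeline_config_path
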
Example #3
def main():
    args = parser.parse_args()
    pipeline_config_path = find_config(args.config_path)
    if args.download or args.mode == 'download':
        deep_download(['-c', pipeline_config_path])
    token = args.token or os.getenv('TELEGRAM_TOKEN')

    if args.mode == 'train':
        train_evaluate_model_from_config(pipeline_config_path)
    elif args.mode == 'evaluate':
        train_evaluate_model_from_config(pipeline_config_path,
                                         to_train=False,
                                         to_validate=False)
    elif args.mode == 'interact':
        interact_model(pipeline_config_path)
    elif args.mode == 'interactbot':
        if not token:
            log.error(
                'Token required: specify the -t parameter or set the TELEGRAM_TOKEN env var with the Telegram bot token'
            )
        else:
            interact_model_by_telegram(pipeline_config_path, token)
    elif args.mode == 'riseapi':
        start_model_server(pipeline_config_path)
    elif args.mode == 'predict':
        predict_on_stream(pipeline_config_path, args.batch_size,
                          args.file_path)
Example #4
def main():
    args = parser.parse_args()
    pipeline_config_path = find_config(args.config_path)

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    log.info("using GPU id: " + args.gpu_id)

    if args.download or args.mode == 'download':
        deep_download(pipeline_config_path)

    multi_instance = args.multi_instance
    stateful = args.stateful

    start_epoch_num = args.start_epoch_num

    if args.mode == 'train':
        train_evaluate_model_from_config(pipeline_config_path,
                                         recursive=args.recursive,
                                         start_epoch_num=start_epoch_num)
    elif args.mode == 'evaluate':
        train_evaluate_model_from_config(pipeline_config_path,
                                         to_train=False,
                                         to_validate=False,
                                         start_epoch_num=start_epoch_num)
    elif args.mode == 'interact':
        interact_model(pipeline_config_path)
    elif args.mode == 'interactbot':
        token = args.token
        interact_model_by_telegram(pipeline_config_path, token)
    elif args.mode == 'interactmsbot':
        ms_id = args.ms_id
        ms_secret = args.ms_secret
        run_ms_bf_default_agent(model_config=pipeline_config_path,
                                app_id=ms_id,
                                app_secret=ms_secret,
                                multi_instance=multi_instance,
                                stateful=stateful)
    elif args.mode == 'riseapi':
        alice = args.api_mode == 'alice'
        https = args.https
        ssl_key = args.key
        ssl_cert = args.cert
        if alice:
            start_alice_server(pipeline_config_path, https, ssl_key, ssl_cert)
        else:
            start_model_server(pipeline_config_path, https, ssl_key, ssl_cert)
    elif args.mode == 'predict':
        predict_on_stream(pipeline_config_path, args.batch_size,
                          args.file_path)
    elif args.mode == 'install':
        install_from_config(pipeline_config_path)
    elif args.mode == 'crossval':
        if args.folds < 2:
            log.error('Minimum number of Folds is 2')
        else:
            n_folds = args.folds
            calc_cv_score(pipeline_config_path, n_folds=n_folds, is_loo=False)
Example #5
def main():
    args = parser.parse_args()
    pipeline_config_path = find_config(args.config_path)
    if args.download or args.mode == 'download':
        deep_download(['-c', pipeline_config_path])
    token = args.token or os.getenv('TELEGRAM_TOKEN')

    if args.mode == 'train':
        train_evaluate_model_from_config(pipeline_config_path)
    elif args.mode == 'evaluate':
        train_evaluate_model_from_config(pipeline_config_path, to_train=False, to_validate=False)
    elif args.mode == 'interact':
        interact_model(pipeline_config_path)
    elif args.mode == 'interactbot':
        if not token:
            log.error('Token required: specify the -t parameter or set the TELEGRAM_TOKEN env var with the Telegram bot token')
        else:
            interact_model_by_telegram(pipeline_config_path, token)
    elif args.mode == 'riseapi':
        start_model_server(pipeline_config_path)
    elif args.mode == 'predict':
        predict_on_stream(pipeline_config_path, args.batch_size, args.file_path)
    elif args.mode == 'install':
        install_from_config(pipeline_config_path)
Example #6
def main():
    args = parser.parse_args()

    pipeline_config_path = find_config(args.config_path)
    https = args.https
    ssl_key = args.key
    ssl_cert = args.cert

    if args.download or args.mode == 'download':
        deep_download(pipeline_config_path)

    multi_instance = args.multi_instance
    stateful = args.stateful

    if args.mode == 'train':
        train_evaluate_model_from_config(pipeline_config_path,
                                         recursive=args.recursive,
                                         start_epoch_num=args.start_epoch_num)
    elif args.mode == 'evaluate':
        train_evaluate_model_from_config(pipeline_config_path,
                                         to_train=False,
                                         start_epoch_num=args.start_epoch_num)
    elif args.mode == 'interact':
        interact_model(pipeline_config_path)
    elif args.mode == 'interactbot':
        token = args.token
        interact_model_by_telegram(
            model_config=pipeline_config_path,
            token=token,
            default_skill_wrap=not args.no_default_skill)
    elif args.mode == 'interactmsbot':
        ms_id = args.ms_id
        ms_secret = args.ms_secret
        run_ms_bf_default_agent(model_config=pipeline_config_path,
                                app_id=ms_id,
                                app_secret=ms_secret,
                                multi_instance=multi_instance,
                                stateful=stateful,
                                port=args.port,
                                https=https,
                                ssl_key=ssl_key,
                                ssl_cert=ssl_cert,
                                default_skill_wrap=not args.no_default_skill)
    elif args.mode == 'alexa':
        run_alexa_default_agent(model_config=pipeline_config_path,
                                multi_instance=multi_instance,
                                stateful=stateful,
                                port=args.port,
                                https=https,
                                ssl_key=ssl_key,
                                ssl_cert=ssl_cert,
                                default_skill_wrap=not args.no_default_skill)
    elif args.mode == 'riseapi':
        alice = args.api_mode == 'alice'
        if alice:
            start_alice_server(pipeline_config_path,
                               https,
                               ssl_key,
                               ssl_cert,
                               port=args.port)
        else:
            start_model_server(pipeline_config_path,
                               https,
                               ssl_key,
                               ssl_cert,
                               port=args.port)
    elif args.mode == 'predict':
        predict_on_stream(pipeline_config_path, args.batch_size,
                          args.file_path)
    elif args.mode == 'install':
        install_from_config(pipeline_config_path)
    elif args.mode == 'crossval':
        if args.folds < 2:
            log.error('Minimum number of Folds is 2')
        else:
            n_folds = args.folds
            calc_cv_score(pipeline_config_path, n_folds=n_folds, is_loo=False)
Example #7
def main():
    args = parser.parse_args()
    pipeline_config_path = find_config(args.config_path)

    if args.download or args.mode == 'download':
        deep_download(['-c', pipeline_config_path])
    token = args.token or os.getenv('TELEGRAM_TOKEN')

    ms_id = args.ms_id or os.getenv('MS_APP_ID')
    ms_secret = args.ms_secret or os.getenv('MS_APP_SECRET')

    multi_instance = args.multi_instance
    stateful = args.stateful

    if args.mode == 'train':
        train_evaluate_model_from_config(pipeline_config_path)
    elif args.mode == 'evaluate':
        train_evaluate_model_from_config(pipeline_config_path,
                                         to_train=False,
                                         to_validate=False)
    elif args.mode == 'interact':
        interact_model(pipeline_config_path)
    elif args.mode == 'interactbot':
        if not token:
            log.error(
                'Token required: specify the -t parameter or set the TELEGRAM_TOKEN env var with the Telegram bot token'
            )
        else:
            interact_model_by_telegram(pipeline_config_path, token)
    elif args.mode == 'interactmsbot':
        if not ms_id:
            log.error(
                'Microsoft Bot Framework app id required: specify the -i parameter '
                'or set the MS_APP_ID env var with the Microsoft app id')
        elif not ms_secret:
            log.error(
                'Microsoft Bot Framework app secret required: specify the -s parameter '
                'or set the MS_APP_SECRET env var with the Microsoft app secret')
        else:
            run_ms_bf_default_agent(model_config_path=pipeline_config_path,
                                    app_id=ms_id,
                                    app_secret=ms_secret,
                                    multi_instance=multi_instance,
                                    stateful=stateful)
    elif args.mode == 'riseapi':
        alice = args.api_mode == 'alice'
        https = args.https
        ssl_key = args.key
        ssl_cert = args.cert
        start_model_server(pipeline_config_path, alice, https, ssl_key,
                           ssl_cert)
    elif args.mode == 'predict':
        predict_on_stream(pipeline_config_path, args.batch_size,
                          args.file_path)
    elif args.mode == 'install':
        install_from_config(pipeline_config_path)
    elif args.mode == 'crossval':
        if args.folds < 2:
            log.error('Minimum number of Folds is 2')
        else:
            n_folds = args.folds
            calc_cv_score(pipeline_config_path=pipeline_config_path,
                          n_folds=n_folds,
                          is_loo=False)
Example #8
def main():
    args = parser.parse_args()
    pipeline_config_path = find_config(args.config_path)

    if args.download or args.mode == 'download':
        deep_download(pipeline_config_path)

    if args.mode == 'train':
        train_evaluate_model_from_config(pipeline_config_path,
                                         recursive=args.recursive,
                                         start_epoch_num=args.start_epoch_num)
    elif args.mode == 'evaluate':
        train_evaluate_model_from_config(pipeline_config_path,
                                         to_train=False,
                                         start_epoch_num=args.start_epoch_num)
    elif args.mode == 'interact':
        interact_model(pipeline_config_path)
    elif args.mode == 'telegram':
        interact_model_by_telegram(model_config=pipeline_config_path,
                                   token=args.token)
    elif args.mode == 'msbot':
        start_ms_bf_server(model_config=pipeline_config_path,
                           app_id=args.ms_id,
                           app_secret=args.ms_secret,
                           port=args.port,
                           https=args.https,
                           ssl_key=args.key,
                           ssl_cert=args.cert)
    elif args.mode == 'alexa':
        start_alexa_server(model_config=pipeline_config_path,
                           port=args.port,
                           https=args.https,
                           ssl_key=args.key,
                           ssl_cert=args.cert)
    elif args.mode == 'alice':
        start_alice_server(model_config=pipeline_config_path,
                           port=args.port,
                           https=args.https,
                           ssl_key=args.key,
                           ssl_cert=args.cert)
    elif args.mode == 'riseapi':
        start_model_server(pipeline_config_path,
                           args.https,
                           args.key,
                           args.cert,
                           port=args.port)
    elif args.mode == 'risesocket':
        start_socket_server(pipeline_config_path,
                            args.socket_type,
                            port=args.port,
                            socket_file=args.socket_file)
    elif args.mode == 'agent-rabbit':
        start_rabbit_service(model_config=pipeline_config_path,
                             service_name=args.service_name,
                             agent_namespace=args.agent_namespace,
                             batch_size=args.batch_size,
                             utterance_lifetime_sec=args.utterance_lifetime,
                             rabbit_host=args.rabbit_host,
                             rabbit_port=args.rabbit_port,
                             rabbit_login=args.rabbit_login,
                             rabbit_password=args.rabbit_password,
                             rabbit_virtualhost=args.rabbit_virtualhost)
    elif args.mode == 'predict':
        predict_on_stream(pipeline_config_path, args.batch_size,
                          args.file_path)
    elif args.mode == 'install':
        install_from_config(pipeline_config_path)
    elif args.mode == 'crossval':
        if args.folds < 2:
            log.error('Minimum number of Folds is 2')
        else:
            calc_cv_score(pipeline_config_path,
                          n_folds=args.folds,
                          is_loo=False)
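
Each variant above is driven entirely by command-line arguments, so it can also be exercised programmatically by substituting sys.argv before calling main(). The entry-point and config names below are placeholders, not the library's actual invocation:

import sys

if __name__ == '__main__':
    # Roughly equivalent to: python deep.py interact my_config.json -d
    sys.argv = ['deep.py', 'interact', 'my_config.json', '-d']
    main()
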