예제 #1
0
def parse_config(args):
    """Load every model listed in a config file and start serving them.

    Reads the config file at ``args.config_path``, checks its structure and
    builds one model per entry of ``configs['model_config_list']``. A model
    that fails to build is logged as a warning and skipped, so the remaining
    models can still be served.

    Args:
        args: Parsed command-line namespace; this function reads
            ``config_path``, ``rest_port`` and ``port``.
    """
    configs = open_config(path=args.config_path)
    check_config_structure(configs=configs)
    models = {}
    for config in configs['model_config_list']:
        model_config = config['config']
        # Resolve the name before the ``try`` so the error handlers below
        # only format an already-known value instead of indexing the config
        # again (which could itself raise while handling an exception).
        model_name = model_config['name']
        try:
            batch_size = model_config.get('batch_size', None)
            model_ver_policy = model_config.get('model_version_policy',
                                                None)
            model = ModelBuilder.build(
                model_name=model_name,
                model_directory=model_config['base_path'],
                batch_size=batch_size,
                model_version_policy=model_ver_policy)
            models[model_name] = model
        except ValidationError as e_val:
            logger.warning("Model version policy for model {} is invalid. "
                           "Exception: {}".format(model_name, e_val))
        except Exception as e:
            logger.warning("Unexpected error occurred in {} model. "
                           "Exception: {}".format(model_name, e))
    if args.rest_port > 0:
        # ``daemon=True`` replaces ``Thread.setDaemon`` (deprecated since
        # Python 3.10); a daemon thread does not keep the process alive.
        process_thread = threading.Thread(target=start_web_rest_server,
                                          args=[models, args.rest_port],
                                          daemon=True)
        process_thread.start()
    start_server(models=models, max_workers=1, port=args.port)
예제 #2
0
def parse_one_model(args):
    """Build a single model from command-line arguments and serve it.

    ``args.model_version_policy`` is a JSON string that is decoded before
    being passed to ``ModelBuilder.build``. Any failure while decoding or
    building is logged, and the process exits with a non-zero status.

    Args:
        args: Parsed command-line namespace; this function reads
            ``model_name``, ``model_path``, ``batch_size``,
            ``model_version_policy``, ``rest_port`` and ``port``.
    """
    try:
        model_version_policy = json.loads(args.model_version_policy)
        model = ModelBuilder.build(model_name=args.model_name,
                                   model_directory=args.model_path,
                                   batch_size=args.batch_size,
                                   model_version_policy=model_version_policy)
    # ``sys.exit()`` without an argument reports success (status 0); these
    # are fatal errors, so exit with status 1 instead.
    except ValidationError as e_val:
        logger.error("Model version policy is invalid. "
                     "Exception: {}".format(e_val))
        sys.exit(1)
    except json.decoder.JSONDecodeError as e_json:
        logger.error("model_version_policy field must be in json format. "
                     "Exception: {}".format(e_json))
        sys.exit(1)
    except Exception as e:
        logger.error("Unexpected error occurred. " "Exception: {}".format(e))
        sys.exit(1)
    models = {args.model_name: model}
    if args.rest_port > 0:
        # ``daemon=True`` replaces the deprecated ``Thread.setDaemon``.
        process_thread = threading.Thread(target=start_web_rest_server,
                                          args=[models, args.rest_port],
                                          daemon=True)
        process_thread.start()
    start_server(models=models, max_workers=1, port=args.port)
예제 #3
0
def parse_one_model(args):
    """Build the one model described by ``args`` and hand it to start_server.

    Args:
        args: Parsed command-line namespace; ``model_name``, ``model_path``,
            ``batch_size`` and ``port`` are read.
    """
    served_model = ModelBuilder.build(
        model_name=args.model_name,
        model_directory=args.model_path,
        batch_size=args.batch_size,
    )
    served_models = {args.model_name: served_model}
    start_server(models=served_models, max_workers=1, port=args.port)
예제 #4
0
def parse_config(args):
    """Load every model listed in a config file and start serving them.

    The config at ``args.config_path`` is validated against
    ``models_config_schema``. Each entry of ``configs['model_config_list']``
    is turned into a model spec and built. Entries that fail to build, or
    for which ``ModelBuilder.build`` returns None, are skipped. If no model
    could be built at all, the process exits with a non-zero status.

    Args:
        args: Parsed command-line namespace; besides what
            ``set_engine_requests_queue_size`` consumes, this function reads
            ``config_path``, ``rest_port``, ``rest_workers``,
            ``grpc_workers`` and ``port``.
    """
    set_engine_requests_queue_size(args)
    configs = open_config(path=args.config_path)
    validate(configs, models_config_schema)
    models = {}
    for config in configs['model_config_list']:
        # Resolve the name before the ``try`` so the handlers below only
        # format an already-known value.
        model_name = config['config']['name']
        try:
            model_spec = get_model_spec(config['config'])

            model = ModelBuilder.build(**model_spec)
            if model is not None:
                models[model_name] = model
        except ValidationError as e_val:
            logger.warning("Model version policy or plugin config "
                           "for model {} is invalid. "
                           "Exception: {}".format(model_name, e_val))
        except Exception as e:
            logger.warning("Unexpected error occurred in {} model. "
                           "Exception: {}".format(model_name, e))
    if not models:
        # Nothing to serve is a fatal condition: log it as an error and exit
        # with a non-zero status (``sys.exit()`` alone would report success).
        logger.error("Could not access any of provided models. Server will "
                     "exit now.")
        sys.exit(1)
    if args.rest_port > 0:
        # ``daemon=True`` replaces the deprecated ``Thread.setDaemon``.
        process_thread = threading.Thread(
            target=start_web_rest_server,
            args=[models, args.rest_port, args.rest_workers],
            daemon=True)
        process_thread.start()
    start_server(models=models, max_workers=args.grpc_workers, port=args.port)
예제 #5
0
def parse_one_model(args):
    """Build a single model from command-line arguments and serve it.

    ``args.model_version_policy`` is a JSON string that is decoded in place on
    ``args`` before ``get_model_spec(vars(args))`` builds the keyword
    arguments for ``ModelBuilder.build``. When both ``shape`` and
    ``batch_size`` are given, a warning is logged and ``batch_size`` is
    dropped. Any failure exits the process with a non-zero status.

    Args:
        args: Parsed command-line namespace; besides the fields consumed by
            ``get_model_spec``, this function reads ``model_version_policy``,
            ``shape``, ``batch_size``, ``model_name``, ``rest_port`` and
            ``port``.
    """
    try:
        args.model_version_policy = json.loads(args.model_version_policy)

        if args.shape is not None and args.batch_size is not None:
            logger.warning(CONFLICTING_PARAMS_WARNING.format(args.model_name))
            args.batch_size = None

        model_spec = get_model_spec(vars(args))

        model = ModelBuilder.build(**model_spec)
    # ``sys.exit()`` without an argument reports success (status 0); these
    # are fatal errors, so exit with status 1 instead.
    except ValidationError as e_val:
        logger.error("Model version policy is invalid. "
                     "Exception: {}".format(e_val))
        sys.exit(1)
    except json.decoder.JSONDecodeError as e_json:
        logger.error("model_version_policy field must be in json format. "
                     "Exception: {}".format(e_json))
        sys.exit(1)
    except Exception as e:
        logger.error("Unexpected error occurred. " "Exception: {}".format(e))
        sys.exit(1)
    # ``ModelBuilder.build`` may return None when the model is not
    # accessible; there is nothing to serve in that case.
    if model is None:
        logger.error("Could not access provided model. Server will exit now.")
        sys.exit(1)
    models = {args.model_name: model}

    if args.rest_port > 0:
        # ``daemon=True`` replaces the deprecated ``Thread.setDaemon``.
        process_thread = threading.Thread(target=start_web_rest_server,
                                          args=[models, args.rest_port],
                                          daemon=True)
        process_thread.start()
    start_server(models=models, max_workers=1, port=args.port)
예제 #6
0
def parse_config(args):
    """Build every model listed in the config file and hand them to start_server.

    Args:
        args: Parsed command-line namespace; ``config_path`` and ``port``
            are read.
    """
    configs = open_config(path=args.config_path)
    check_config_structure(configs=configs)
    models = {}
    for entry in configs['model_config_list']:
        settings = entry['config']
        name = settings['name']
        models[name] = ModelBuilder.build(
            model_name=name,
            model_directory=settings['base_path'])
    start_server(models=models, max_workers=1, port=args.port)
예제 #7
0
def parse_one_model(args):
    """Build a single model from command-line arguments and serve it.

    ``args.model_version_policy`` is a JSON string that is decoded before
    being passed to ``ModelBuilder.build``. Any failure while decoding or
    building is logged, and the process exits with a non-zero status.

    Args:
        args: Parsed command-line namespace; this function reads
            ``model_name``, ``model_path``, ``batch_size``,
            ``model_version_policy`` and ``port``.
    """
    try:
        model_version_policy = json.loads(args.model_version_policy)
        model = ModelBuilder.build(model_name=args.model_name,
                                   model_directory=args.model_path,
                                   batch_size=args.batch_size,
                                   model_version_policy=model_version_policy)
    # ``sys.exit()`` without an argument reports success (status 0); these
    # are fatal errors, so exit with status 1 instead.
    except ValidationError as e_val:
        logger.error("Model version policy is invalid. "
                     "Exception: {}".format(e_val))
        sys.exit(1)
    except json.decoder.JSONDecodeError as e_json:
        logger.error("model_version_policy field must be in json format. "
                     "Exception: {}".format(e_json))
        sys.exit(1)
    except Exception as e:
        logger.error("Unexpected error occurred. " "Exception: {}".format(e))
        sys.exit(1)
    start_server(models={args.model_name: model},
                 max_workers=1,
                 port=args.port)
예제 #8
0
def parse_one_model(args):
    """Build a single model from command-line arguments and serve it.

    ``args.model_version_policy`` and, when given, ``args.plugin_config`` are
    JSON strings. They are decoded in place on ``args`` before
    ``get_model_spec(vars(args))`` builds the keyword arguments for
    ``ModelBuilder.build``. Any failure exits the process with a non-zero
    status.

    Args:
        args: Parsed command-line namespace; besides the fields consumed by
            ``set_engine_requests_queue_size`` and ``get_model_spec``, this
            function reads ``model_name``, ``rest_port``, ``rest_workers``,
            ``grpc_workers`` and ``port``.
    """
    set_engine_requests_queue_size(args)
    try:
        args.model_version_policy = json.loads(args.model_version_policy)
        if args.plugin_config is not None:
            args.plugin_config = json.loads(args.plugin_config)

        model_spec = get_model_spec(vars(args))

        model = ModelBuilder.build(**model_spec)
    # ``sys.exit()`` without an argument reports success (status 0); these
    # are fatal errors, so exit with status 1 instead.
    except ValidationError as e_val:
        logger.error("Model version policy or plugin config is invalid. "
                     "Exception: {}".format(e_val))
        sys.exit(1)
    except json.decoder.JSONDecodeError as e_json:
        logger.error("model_version_policy and plugin_config fields must be "
                     "in json format. "
                     "Exception: {}".format(e_json))
        sys.exit(1)
    except Exception as e:
        logger.error("Unexpected error occurred. " "Exception: {}".format(e))
        sys.exit(1)
    # ``ModelBuilder.build`` may return None when the model is not
    # accessible; there is nothing to serve in that case.
    if model is None:
        logger.error("Could not access provided model. Server will exit now.")
        sys.exit(1)
    models = {args.model_name: model}

    if args.rest_port > 0:
        # The REST server runs in a daemon thread next to start_server below;
        # ``daemon=True`` replaces the deprecated ``Thread.setDaemon``.
        process_thread = threading.Thread(
            target=start_web_rest_server,
            args=[models, args.rest_port, args.rest_workers],
            daemon=True)
        process_thread.start()
    start_server(models=models, max_workers=args.grpc_workers, port=args.port)
def test_build_s3_model(mocker):
    """An ``s3://`` model path must be built through ``S3Model.build``."""
    patched_s3_build = mocker.patch(
        'ie_serving.models.s3_model.S3Model.build')
    ModelBuilder.build('model_name', 's3://bucket/model')
    assert patched_s3_build.called
def test_build_gs_model(mocker):
    """A ``gs://`` model path must be built through ``GSModel.build``."""
    patched_gs_build = mocker.patch(
        'ie_serving.models.gs_model.GSModel.build')
    ModelBuilder.build('model_name', 'gs://bucket/model')
    assert patched_gs_build.called
def test_build_local_model(mocker):
    """A plain filesystem path must be built through ``LocalModel.build``."""
    local_target = 'ie_serving.models.local_model.LocalModel.build'
    patched_local_build = mocker.patch(local_target)
    ModelBuilder.build('model_name', 'opt/bucket/model')
    assert patched_local_build.called