def parse_config(args):
    """Build every model listed in the config file and start serving them.

    Reads the config file at ``args.config_path``, checks its structure with
    ``check_config_structure`` and builds one model per entry of
    ``model_config_list``. An entry that fails to build is logged as a
    warning and skipped, so the remaining models can still be served.

    If ``args.rest_port`` is positive, ``start_web_rest_server`` is started on
    a daemon thread. ``start_server`` is then called in the calling thread.

    Args:
        args: parsed command-line namespace; this function reads
            ``config_path``, ``rest_port`` and ``port``.
    """
    configs = open_config(path=args.config_path)
    check_config_structure(configs=configs)
    models = {}
    for config in configs['model_config_list']:
        # Resolve the name defensively so the except handlers below cannot
        # themselves raise KeyError on a malformed entry.
        model_name = config.get('config', {}).get('name')
        try:
            model_config = config['config']
            batch_size = model_config.get('batch_size', None)
            model_ver_policy = model_config.get('model_version_policy', None)
            model = ModelBuilder.build(
                model_name=model_config['name'],
                model_directory=model_config['base_path'],
                batch_size=batch_size,
                model_version_policy=model_ver_policy)
            models[model_config['name']] = model
        except ValidationError as e_val:
            logger.warning("Model version policy for model {} is invalid. "
                           "Exception: {}".format(model_name, e_val))
        except Exception as e:
            logger.warning("Unexpected error occurred in {} model. "
                           "Exception: {}".format(model_name, e))
    if args.rest_port > 0:
        # Daemon thread: it does not keep the interpreter alive on its own
        # once the main thread finishes. ``daemon=`` replaces the
        # deprecated ``Thread.setDaemon``.
        rest_thread = threading.Thread(target=start_web_rest_server,
                                       args=[models, args.rest_port],
                                       daemon=True)
        rest_thread.start()
    start_server(models=models, max_workers=1, port=args.port)
def parse_one_model(args):
    """Build the single model described on the command line and serve it.

    ``args.model_version_policy`` is decoded from JSON and passed, together
    with ``model_name``, ``model_path`` and ``batch_size``, to
    ``ModelBuilder.build``. If the policy is not valid JSON, fails
    validation, or building fails for any other reason, the error is logged
    and the process exits with a non-zero status.

    If ``args.rest_port`` is positive, ``start_web_rest_server`` is started on
    a daemon thread. ``start_server`` is then called in the calling thread.

    Args:
        args: parsed command-line namespace.
    """
    try:
        model_version_policy = json.loads(args.model_version_policy)
        model = ModelBuilder.build(model_name=args.model_name,
                                   model_directory=args.model_path,
                                   batch_size=args.batch_size,
                                   model_version_policy=model_version_policy)
    except ValidationError as e_val:
        logger.error("Model version policy is invalid. "
                     "Exception: {}".format(e_val))
        # Non-zero status: a bare sys.exit() would report success.
        sys.exit(1)
    except json.decoder.JSONDecodeError as e_json:
        logger.error("model_version_policy field must be in json format. "
                     "Exception: {}".format(e_json))
        sys.exit(1)
    except Exception as e:
        logger.error("Unexpected error occurred. "
                     "Exception: {}".format(e))
        sys.exit(1)
    models = {args.model_name: model}
    if args.rest_port > 0:
        # Daemon thread; ``daemon=`` replaces the deprecated setDaemon().
        rest_thread = threading.Thread(target=start_web_rest_server,
                                       args=[models, args.rest_port],
                                       daemon=True)
        rest_thread.start()
    start_server(models=models, max_workers=1, port=args.port)
def parse_one_model(args):
    """Build the model named on the command line and serve it.

    The model is built from ``args.model_name``, ``args.model_path`` and
    ``args.batch_size``. It is handed to ``start_server`` with a single
    worker on ``args.port``.
    """
    built = ModelBuilder.build(model_name=args.model_name,
                               model_directory=args.model_path,
                               batch_size=args.batch_size)
    served_models = {args.model_name: built}
    start_server(models=served_models, max_workers=1, port=args.port)
def parse_config(args):
    """Build every model listed in the config file and start serving them.

    Calls ``set_engine_requests_queue_size(args)``, then loads the config at
    ``args.config_path`` and validates it against ``models_config_schema``.
    Each entry of ``model_config_list`` goes through ``get_model_spec`` and
    ``ModelBuilder.build``. Entries that raise are logged and skipped, and
    entries for which the builder returns ``None`` are silently skipped. If
    no model could be built, the process exits with a non-zero status.

    If ``args.rest_port`` is positive, ``start_web_rest_server`` is started on
    a daemon thread with ``args.rest_workers``. ``start_server`` is then
    called with ``args.grpc_workers`` workers.

    Args:
        args: parsed command-line namespace.
    """
    set_engine_requests_queue_size(args)
    configs = open_config(path=args.config_path)
    validate(configs, models_config_schema)
    models = {}
    for config in configs['model_config_list']:
        # Resolve the name defensively so the except handlers below cannot
        # themselves raise KeyError on a malformed entry.
        model_name = config.get('config', {}).get('name')
        try:
            model_spec = get_model_spec(config['config'])
            model = ModelBuilder.build(**model_spec)
            if model is not None:
                models[config['config']['name']] = model
        except ValidationError as e_val:
            logger.warning("Model version policy or plugin config "
                           "for model {} is invalid. "
                           "Exception: {}".format(model_name, e_val))
        except Exception as e:
            logger.warning("Unexpected error occurred in {} model. "
                           "Exception: {}".format(model_name, e))
    if not models:
        logger.error("Could not access any of provided models. Server will "
                     "exit now.")
        # Non-zero status: a bare sys.exit() would report success.
        sys.exit(1)
    if args.rest_port > 0:
        # Daemon thread; ``daemon=`` replaces the deprecated setDaemon().
        rest_thread = threading.Thread(
            target=start_web_rest_server,
            args=[models, args.rest_port, args.rest_workers],
            daemon=True)
        rest_thread.start()
    start_server(models=models, max_workers=args.grpc_workers, port=args.port)
def parse_one_model(args):
    """Build the single model described on the command line and serve it.

    ``args.model_version_policy`` is decoded from JSON in place on ``args``.
    If both ``args.shape`` and ``args.batch_size`` are given, a warning is
    logged and ``batch_size`` is dropped. ``get_model_spec(vars(args))``
    then supplies the keyword arguments for ``ModelBuilder.build``.

    The process exits with a non-zero status when decoding, validation or
    building fails, or when the builder returns ``None``. Otherwise a REST
    server is started on a daemon thread (if ``args.rest_port`` is positive)
    and ``start_server`` is called in the calling thread.

    Args:
        args: parsed command-line namespace; mutated in place.
    """
    try:
        args.model_version_policy = json.loads(args.model_version_policy)
        if args.shape is not None and args.batch_size is not None:
            # The two options conflict; shape is kept and batch_size dropped.
            logger.warning(CONFLICTING_PARAMS_WARNING.format(args.model_name))
            args.batch_size = None
        model_spec = get_model_spec(vars(args))
        model = ModelBuilder.build(**model_spec)
    except ValidationError as e_val:
        logger.error("Model version policy is invalid. "
                     "Exception: {}".format(e_val))
        # Non-zero status: a bare sys.exit() would report success.
        sys.exit(1)
    except json.decoder.JSONDecodeError as e_json:
        logger.error("model_version_policy field must be in json format. "
                     "Exception: {}".format(e_json))
        sys.exit(1)
    except Exception as e:
        logger.error("Unexpected error occurred. "
                     "Exception: {}".format(e))
        sys.exit(1)
    if model is None:
        logger.error("Could not access provided model. Server will exit now.")
        sys.exit(1)
    models = {args.model_name: model}
    if args.rest_port > 0:
        # Daemon thread; ``daemon=`` replaces the deprecated setDaemon().
        rest_thread = threading.Thread(target=start_web_rest_server,
                                       args=[models, args.rest_port],
                                       daemon=True)
        rest_thread.start()
    start_server(models=models, max_workers=1, port=args.port)
def parse_config(args):
    """Build every model listed in the config file and serve them.

    The config at ``args.config_path`` is opened and structure-checked. One
    model is built per ``model_config_list`` entry, keyed by its ``name``.
    Build errors are not caught here and propagate to the caller.
    """
    configs = open_config(path=args.config_path)
    check_config_structure(configs=configs)
    served_models = {}
    for entry in configs['model_config_list']:
        settings = entry['config']
        served_models[settings['name']] = ModelBuilder.build(
            model_name=settings['name'],
            model_directory=settings['base_path'])
    start_server(models=served_models, max_workers=1, port=args.port)
def parse_one_model(args):
    """Build the single model described on the command line and serve it.

    ``args.model_version_policy`` is decoded from JSON and passed, together
    with ``model_name``, ``model_path`` and ``batch_size``, to
    ``ModelBuilder.build``. On any failure the error is logged and the
    process exits with a non-zero status. On success the model is served by
    ``start_server`` with a single worker on ``args.port``.

    Args:
        args: parsed command-line namespace.
    """
    try:
        model_version_policy = json.loads(args.model_version_policy)
        model = ModelBuilder.build(model_name=args.model_name,
                                   model_directory=args.model_path,
                                   batch_size=args.batch_size,
                                   model_version_policy=model_version_policy)
    except ValidationError as e_val:
        logger.error("Model version policy is invalid. "
                     "Exception: {}".format(e_val))
        # Non-zero status: a bare sys.exit() would report success.
        sys.exit(1)
    except json.decoder.JSONDecodeError as e_json:
        logger.error("model_version_policy field must be in json format. "
                     "Exception: {}".format(e_json))
        sys.exit(1)
    except Exception as e:
        logger.error("Unexpected error occurred. "
                     "Exception: {}".format(e))
        sys.exit(1)
    start_server(models={args.model_name: model}, max_workers=1,
                 port=args.port)
def parse_one_model(args):
    """Build the single model described on the command line and serve it.

    Calls ``set_engine_requests_queue_size(args)`` first. Then
    ``args.model_version_policy`` and, when given, ``args.plugin_config``
    are decoded from JSON in place on ``args``.
    ``get_model_spec(vars(args))`` supplies the keyword arguments for
    ``ModelBuilder.build``.

    The process exits with a non-zero status when decoding, validation or
    building fails, or when the builder returns ``None``. Otherwise a REST
    server is started on a daemon thread with ``args.rest_workers`` (if
    ``args.rest_port`` is positive). ``start_server`` is then called with
    ``args.grpc_workers`` workers.

    Args:
        args: parsed command-line namespace; mutated in place.
    """
    set_engine_requests_queue_size(args)
    try:
        args.model_version_policy = json.loads(args.model_version_policy)
        if args.plugin_config is not None:
            args.plugin_config = json.loads(args.plugin_config)
        model_spec = get_model_spec(vars(args))
        model = ModelBuilder.build(**model_spec)
    except ValidationError as e_val:
        logger.error("Model version policy or plugin config is invalid. "
                     "Exception: {}".format(e_val))
        # Non-zero status: a bare sys.exit() would report success.
        sys.exit(1)
    except json.decoder.JSONDecodeError as e_json:
        logger.error("model_version_policy and plugin_config fields must be "
                     "in json format. "
                     "Exception: {}".format(e_json))
        sys.exit(1)
    except Exception as e:
        logger.error("Unexpected error occurred. "
                     "Exception: {}".format(e))
        sys.exit(1)
    if model is None:
        logger.error("Could not access provided model. Server will exit now.")
        sys.exit(1)
    models = {args.model_name: model}
    if args.rest_port > 0:
        # Daemon thread; ``daemon=`` replaces the deprecated setDaemon().
        rest_thread = threading.Thread(
            target=start_web_rest_server,
            args=[models, args.rest_port, args.rest_workers],
            daemon=True)
        rest_thread.start()
    start_server(models=models, max_workers=args.grpc_workers, port=args.port)
def test_build_s3_model(mocker):
    """Building from an ``s3://`` path must invoke ``S3Model.build``."""
    model_path = 's3://bucket/model'
    patched_build = mocker.patch('ie_serving.models.s3_model.S3Model.build')
    ModelBuilder.build('model_name', model_path)
    patched_build.assert_called()
def test_build_gs_model(mocker):
    """Building from a ``gs://`` path must invoke ``GSModel.build``."""
    model_path = 'gs://bucket/model'
    patched_build = mocker.patch('ie_serving.models.gs_model.GSModel.build')
    ModelBuilder.build('model_name', model_path)
    patched_build.assert_called()
def test_build_local_model(mocker):
    """Building from a plain filesystem path must invoke ``LocalModel.build``."""
    model_path = 'opt/bucket/model'
    patched_build = mocker.patch(
        'ie_serving.models.local_model.LocalModel.build')
    ModelBuilder.build('model_name', model_path)
    patched_build.assert_called()