def test_onnx_fails_fast():
    """An ONNX archive is rejected up front with a hint to convert it."""
    onnx_models = {'onnx': 's3://bucket/prefix/whatever.onnx'}
    with pytest.raises(ValueError) as excinfo:
        ModelLoader.load(onnx_models)
    assert 'Convert ONNX model' in str(excinfo.value)
def test_invalid_model_path_input():
    """
    Test to ensure that folder being created is removed, when path is invalid
    """
    bad_models = {'squeezenet_v1': 'invalid_model_file_path.model'}
    with pytest.raises(Exception):
        ModelLoader.load(bad_models)
    # The loader must clean up the directory it started to create.
    assert os.path.exists('invalid_model_file_path') is False
def _arg_process(self):
    """Process arguments before starting service or create application.

    Normalizes host/port, enforces the model-count limit, downloads and
    loads the requested models, registers the user-defined model service,
    sets up OpenAPI endpoints, optionally generates a client SDK, and
    starts metrics collection.  Any failure is logged (with traceback)
    and the process exits with status 1.
    """
    try:
        # Port/host defaults: 8080 and loopback when not supplied.
        self.port = int(self.args.port) if self.args.port else 8080
        self.host = self.args.host or '127.0.0.1'

        # Enforce the system limit BEFORE downloading/extracting the
        # archives — the original checked only after loading, wasting
        # exactly the work the limit is meant to bound.
        if len(self.args.models) > 5:
            raise Exception('Model number exceeds our system limits: 5')

        # Download and/or extract the given model files.
        models = ModelLoader.load(self.args.models)

        # Register the user-defined model service, falling back to the
        # service named in the first model's manifest.
        manifest = models[0][3]
        service_file = os.path.join(models[0][2],
                                    manifest['Model']['Service'])
        class_defs = self.serving_frontend.register_module(
            self.args.service or service_file)
        # Keep only leaf classes: the user's concrete service is the one
        # nothing else derives from.
        class_defs = list(
            filter(lambda c: len(c.__subclasses__()) == 0, class_defs))
        if len(class_defs) != 1:
            raise Exception(
                'There should be one user defined service derived from ModelService.'
            )
        model_class_name = class_defs[0].__name__

        # Load models using the registered model definition.
        registered_models = self.serving_frontend.get_registered_modelservices(
        )
        ModelClassDef = registered_models[model_class_name]
        self.serving_frontend.load_models(models, ModelClassDef, self.gpu)

        # Setup endpoint
        openapi_endpoints = self.serving_frontend.setup_openapi_endpoints(
            self.host, self.port)

        # Generate client SDK
        if self.args.gen_api is not None:
            ClientSDKGenerator.generate(openapi_endpoints, self.args.gen_api)

        # Generate metrics to target location (log, csv ...), default to log
        MetricsManager.start(self.args.metrics_write_to, self.args.models,
                             Lock())
    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception('Failed to process arguments: %s', e)
        exit(1)
def _arg_process(self):
    """Process arguments before starting service or create application.

    Normalizes host/port, enforces the model-count limit, downloads and
    loads the requested models, registers the user-defined model service,
    sets up OpenAPI endpoints, optionally generates a client SDK, and
    starts metrics collection.  Any failure is logged (with traceback)
    and the process exits with status 1.
    """
    try:
        # Port/host defaults: 8080 and loopback when not supplied.
        self.port = int(self.args.port) if self.args.port else 8080
        self.host = self.args.host or '127.0.0.1'

        # Enforce the system limit BEFORE downloading/extracting the
        # archives — the original checked only after loading, wasting
        # exactly the work the limit is meant to bound.
        if len(self.args.models) > 5:
            raise Exception('Model number exceeds our system limits: 5')

        # Download and/or extract the given model files.
        models = ModelLoader.load(self.args.models)

        # Register the user-defined model service, falling back to the
        # service named in the first model's manifest.
        manifest = models[0][3]
        service_file = os.path.join(models[0][2], manifest['Model']['Service'])
        class_defs = self.serving_frontend.register_module(
            self.args.service or service_file)
        # Keep only leaf classes: the user's concrete service is the one
        # nothing else derives from.
        class_defs = list(
            filter(lambda c: len(c.__subclasses__()) == 0, class_defs))
        if len(class_defs) != 1:
            raise Exception(
                'There should be one user defined service derived from ModelService.')
        model_class_name = class_defs[0].__name__

        # Load models using the registered model definition.
        registered_models = self.serving_frontend.get_registered_modelservices()
        ModelClassDef = registered_models[model_class_name]
        self.serving_frontend.load_models(models, ModelClassDef, self.gpu)

        # Setup endpoint
        openapi_endpoints = self.serving_frontend.setup_openapi_endpoints(
            self.host, self.port)

        # Generate client SDK
        if self.args.gen_api is not None:
            ClientSDKGenerator.generate(openapi_endpoints, self.args.gen_api)

        # Generate metrics to target location (log, csv ...), default to log
        MetricsManager.start(self.args.metrics_write_to, self.args.models,
                             Lock())
    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception('Failed to process arguments: %s', e)
        exit(1)
def start_serving(app_name='mms', args=None):
    """Start service routing.

    Parameters
    ----------
    app_name : str
        App name to initialize mms service.
    args : List of str
        Arguments for starting service. By default it is None
        and commandline arguments will be used. It should follow
        the format recognized by python argparse parse_args method:
        https://docs.python.org/3/library/argparse.html#argparse.ArgumentParser.parse_args.
        An example for mms arguments:
        ['--models', 'resnet-18=path1', 'inception_v3=path2',
        '--gen-api', 'java', '--port', '8080']
    """
    # Parse the given arguments (falls back to sys.argv when args is None).
    parsed_args = ArgParser.extract_args(args)

    # Download and/or extract the requested model archives.
    loaded_models = ModelLoader.load(parsed_args.models)

    # Build the MMS application and hand control to its serving loop.
    server = MMS(app_name, args=parsed_args, models=loaded_models)
    server.start_model_serving()
# Module-level bootstrap: load the bundled SqueezeNet model and resolve the
# model-service class for serving (e.g. inside an AWS Lambda container).
from mms.serving_frontend import ServingFrontend
from mms.model_loader import ModelLoader

# Root logger at INFO so startup steps are visible in the host's log stream.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

serving_frontend = ServingFrontend(__name__)

# On AWS Lambda (LAMBDA_TASK_ROOT is set) only /tmp is writable, so the
# packaged model is copied out of the read-only task root before loading.
model_path = 'squeezenet_v1.1.model'
if os.environ.get('LAMBDA_TASK_ROOT', False):
    shutil.copyfile('/var/task/squeezenet_v1.1.model',
                    '/tmp/squeezenet_v1.1.model')
    model_path = '/tmp/squeezenet_v1.1.model'

# Load the single model; each entry appears to be a tuple whose index 2 is
# the extracted model directory and index 3 the parsed manifest —
# TODO confirm against ModelLoader.load's return contract.
models = ModelLoader.load({'squeezenet': model_path})
manifest = models[0][3]
# Path of the service module declared in the model's manifest.
service_file = os.path.join(models[0][2], manifest['Model']['Service'])

class_defs = serving_frontend.register_module(service_file)
if len(class_defs) < 1:
    raise Exception('User defined module must derive base ModelService.')
# The overrided class is the last one in class_defs
# NOTE(review): name looks like a typo for "model_class_name"; kept as-is
# because it is module-level and may be referenced elsewhere in this file.
mode_class_name = class_defs[-1].__name__

# Load models using registered model definitions
registered_models = serving_frontend.get_registered_modelservices()
ModelClassDef = registered_models[mode_class_name]