# Example #1
def test_onnx_fails_fast():
    """An ONNX archive must be rejected up front with a conversion hint."""
    model_dict = {'onnx': 's3://bucket/prefix/whatever.onnx'}

    with pytest.raises(ValueError) as excinfo:
        ModelLoader.load(model_dict)

    # The error message should tell the user to convert the model first.
    assert 'Convert ONNX model' in str(excinfo.value)
# Example #2
def test_invalid_model_path_input():
    """Verify the extraction folder is removed when the model path is invalid."""
    bad_models = {'squeezenet_v1': 'invalid_model_file_path.model'}

    with pytest.raises(Exception):
        ModelLoader.load(bad_models)

    # The loader must not leave a partially-created directory behind.
    assert not os.path.exists('invalid_model_file_path')
# Example #3
    def _arg_process(self):
        """Process arguments before starting service or create application.

        Resolves host/port, enforces the model-count limit, downloads and
        extracts the model archives, registers the user-defined model
        service class, loads the models, sets up OpenAPI endpoints,
        optionally generates a client SDK, and starts metrics collection.

        On any failure the error is logged and the process exits with
        status 1.
        """
        try:
            # Port and host, with defaults for local serving.
            self.port = int(self.args.port) if self.args.port else 8080
            self.host = self.args.host or '127.0.0.1'

            # Enforce the system limit BEFORE any expensive download or
            # extraction work; the original checked this only after the
            # models had already been loaded and registered.
            if len(self.args.models) > 5:
                raise Exception('Model number exceeds our system limits: 5')

            # Download and/or extract the given model archives.
            models = ModelLoader.load(self.args.models)

            # Register user defined model service or default mxnet_vision_service.
            # models[0][2] is used as the extracted model directory and
            # models[0][3] as the parsed manifest (per the calls below).
            manifest = models[0][3]
            service_file = os.path.join(models[0][2],
                                        manifest['Model']['Service'])

            class_defs = self.serving_frontend.register_module(
                self.args.service or service_file)
            # Keep only leaf classes: the user's concrete service subclass
            # has no further subclasses of its own.
            class_defs = list(
                filter(lambda c: len(c.__subclasses__()) == 0, class_defs))

            if len(class_defs) != 1:
                raise Exception(
                    'There should be one user defined service derived from ModelService.'
                )
            model_class_name = class_defs[0].__name__

            # Load models using registered model definitions
            registered_models = self.serving_frontend.get_registered_modelservices(
            )
            ModelClassDef = registered_models[model_class_name]

            self.serving_frontend.load_models(models, ModelClassDef, self.gpu)

            # Setup endpoint
            openapi_endpoints = self.serving_frontend.setup_openapi_endpoints(
                self.host, self.port)

            # Generate client SDK
            if self.args.gen_api is not None:
                ClientSDKGenerator.generate(openapi_endpoints,
                                            self.args.gen_api)

            # Generate metrics to target location (log, csv ...), default to log
            MetricsManager.start(self.args.metrics_write_to, self.args.models,
                                 Lock())

        except Exception as e:
            logger.error('Failed to process arguments: ' + str(e))
            exit(1)
    def _arg_process(self):
        """Process arguments before starting service or create application.

        Resolves host/port, enforces the model-count limit, downloads and
        extracts the model archives, registers the user-defined model
        service class, loads the models, sets up OpenAPI endpoints,
        optionally generates a client SDK, and starts metrics collection.

        On any failure the error is logged and the process exits with
        status 1.
        """
        try:
            # Port and host, with defaults for local serving.
            self.port = int(self.args.port) if self.args.port else 8080
            self.host = self.args.host or '127.0.0.1'

            # Enforce the system limit BEFORE any expensive download or
            # extraction work; the original checked this only after the
            # models had already been loaded and registered.
            if len(self.args.models) > 5:
                raise Exception('Model number exceeds our system limits: 5')

            # Download and/or extract the given model archives.
            models = ModelLoader.load(self.args.models)

            # Register user defined model service or default mxnet_vision_service.
            # models[0][2] is used as the extracted model directory and
            # models[0][3] as the parsed manifest (per the calls below).
            manifest = models[0][3]
            service_file = os.path.join(models[0][2], manifest['Model']['Service'])

            class_defs = self.serving_frontend.register_module(self.args.service or service_file)
            # Keep only leaf classes: the user's concrete service subclass
            # has no further subclasses of its own.
            class_defs = list(filter(lambda c: len(c.__subclasses__()) == 0, class_defs))

            if len(class_defs) != 1:
                raise Exception('There should be one user defined service derived from ModelService.')
            model_class_name = class_defs[0].__name__

            # Load models using registered model definitions
            registered_models = self.serving_frontend.get_registered_modelservices()
            ModelClassDef = registered_models[model_class_name]

            self.serving_frontend.load_models(models, ModelClassDef, self.gpu)

            # Setup endpoint
            openapi_endpoints = self.serving_frontend.setup_openapi_endpoints(self.host, self.port)

            # Generate client SDK
            if self.args.gen_api is not None:
                ClientSDKGenerator.generate(openapi_endpoints, self.args.gen_api)

            # Generate metrics to target location (log, csv ...), default to log
            MetricsManager.start(self.args.metrics_write_to, self.args.models, Lock())

        except Exception as e:
            logger.error('Failed to process arguments: ' + str(e))
            exit(1)
def start_serving(app_name='mms', args=None):
    """Start service routing.

    Parameters
    ----------
    app_name : str
        App name to initialize mms service.
    args : List of str
        Arguments for starting service. By default it is None
        and commandline arguments will be used. It should follow
        the format recognized by python argparse parse_args method:
        https://docs.python.org/3/library/argparse.html#argparse.ArgumentParser.parse_args.
        An example for mms arguments:
        ['--models', 'resnet-18=path1', 'inception_v3=path2',
         '--gen-api', 'java', '--port', '8080']
    """
    # Parse the given arguments (falls back to sys.argv when args is None).
    parsed_args = ArgParser.extract_args(args)

    # Download and/or extract the given model files.
    loaded_models = ModelLoader.load(parsed_args.models)

    # Instantiate an MMS object and begin serving.
    server = MMS(app_name, args=parsed_args, models=loaded_models)
    server.start_model_serving()
# Example #6
from mms.serving_frontend import ServingFrontend
from mms.model_loader import ModelLoader

# Root logger at INFO so frontend/loader messages are visible.
# NOTE(review): `logging`, `os`, and `shutil` are used below but not
# imported in this snippet — presumably imported earlier in the file.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

serving_frontend = ServingFrontend(__name__)

model_path = 'squeezenet_v1.1.model'

# When running on AWS Lambda (LAMBDA_TASK_ROOT is set), only /tmp is
# writable, so copy the bundled model archive there first.
if os.environ.get('LAMBDA_TASK_ROOT', False):
    shutil.copyfile('/var/task/squeezenet_v1.1.model', '/tmp/squeezenet_v1.1.model')
    model_path = '/tmp/squeezenet_v1.1.model'

# Download/extract the archive; models[0][2] is used below as the
# extracted model directory and models[0][3] as the parsed manifest.
models = ModelLoader.load({'squeezenet': model_path})
    
manifest = models[0][3]
service_file = os.path.join(models[0][2], manifest['Model']['Service'])

# Register the model-service module shipped inside the archive.
class_defs = serving_frontend.register_module(service_file)

if len(class_defs) < 1:
    raise Exception('User defined module must derive base ModelService.')
# The user's overriding class is the last one in class_defs
mode_class_name = class_defs[-1].__name__

# Load models using registered model definitions
registered_models = serving_frontend.get_registered_modelservices()
ModelClassDef = registered_models[mode_class_name]