def inference_handler(model_dir):
    """Create and initialize a ``DefaultHandlerService`` for local inference.

    Args:
        model_dir: Path to the model directory. Its ``code`` subdirectory is
            placed first on ``sys.path`` so modules stored there are
            importable (presumably the user's inference script -- confirm
            against the handler's loading logic).

    Returns:
        A ``(handler, context)`` tuple: the initialized
        ``DefaultHandlerService`` and the MMS ``Context`` it was initialized
        with.

    Raises:
        ValueError: If ``model_dir`` does not exist or is not a directory.
        click.UsageError: If ``sagemaker-inference`` or
            ``multi-model-server`` is not installed.
    """
    # A regular file would pass os.path.exists but cannot be a model dir.
    if not os.path.isdir(model_dir):
        raise ValueError(
            "Model directory [{}] does not exist or is not a directory".format(
                model_dir))

    # Resolve both optional dependencies before mutating sys.path, so a
    # missing package does not leave the interpreter's import path changed.
    try:
        from sagemaker_inference.default_handler_service import DefaultHandlerService
    except ImportError:
        raise click.UsageError(
            "Install sagemaker-inference to use local inference")
    try:
        from mms.context import Context
    except ImportError:
        raise click.UsageError(
            "Install multi-model-server to use local inference")

    # Put the model's code directory at the front of the import path. Drop an
    # existing entry first so repeated calls don't accumulate duplicates.
    code_dir = os.path.join(model_dir, 'code')
    if code_dir in sys.path:
        sys.path.remove(code_dir)
    sys.path.insert(0, code_dir)

    handler = DefaultHandlerService()
    context = Context(
        model_name='local-model',
        model_dir=model_dir,
        manifest=None,
        batch_size=None,
        gpu=None,
        mms_version=None
    )
    handler.initialize(context)
    return handler, context
# Example #2
0
def test_handle():
    """``handle`` forwards data and context to the transformer and returns its result."""
    transformer = Mock()
    transformer.transform.return_value = TRANSFORMED_RESULT

    handler_service = DefaultHandlerService(transformer)
    result = handler_service.handle(DATA, CONTEXT)

    assert result == TRANSFORMED_RESULT
    # `called_once_with` is not a Mock assertion: accessing it auto-creates a
    # truthy child mock (or raises AttributeError on Python 3.12+), so it could
    # never fail. Use the real assertion method.
    transformer.transform.assert_called_once_with(DATA, CONTEXT)
# Example #3
0
def test_initialize():
    """``initialize(context)`` calls the transformer's ``validate_and_initialize`` once.

    ``context.system_properties`` is a MagicMock whose item lookup is backed by
    ``properties``, so ``system_properties["model_dir"]`` yields the model path.
    """
    transformer = Mock()
    properties = {"model_dir": "/opt/ml/models/model-name"}

    def getitem(key):
        return properties[key]

    context = MagicMock()
    context.system_properties.__getitem__.side_effect = getitem
    DefaultHandlerService(transformer).initialize(context)

    # Assert on the method mock itself. Calling it and reading `.called_once`
    # would record an extra call and yield an auto-created truthy attribute
    # (or AttributeError on Python 3.12+), so it could never fail.
    transformer.validate_and_initialize.assert_called_once()
# Example #4
0
def test_initialize():
    """``initialize()`` with no context calls ``validate_and_initialize`` exactly once."""
    transformer = Mock()

    DefaultHandlerService(transformer).initialize()

    # Assert on the method mock itself. `validate_and_initialize().called_once()`
    # made its own call and then read an auto-created truthy attribute (or
    # raised AttributeError on Python 3.12+), so it could never fail.
    transformer.validate_and_initialize.assert_called_once()
# Example #5
0
def test_default_handler_service_custom_transformer():
    """A transformer passed to the constructor is stored as ``_service`` unchanged."""
    transformer = Mock()

    handler_service = DefaultHandlerService(transformer)

    # Identity, not equality: the service must keep the very instance it was
    # given, not an equal copy or wrapper.
    assert handler_service._service is transformer
# Example #6
0
def test_default_handler_service(import_lib):
    """With no transformer argument, the service builds a default ``Transformer``.

    ``import_lib`` is requested only for its fixture side effects.
    """
    default_service = DefaultHandlerService()._service

    assert isinstance(default_service, Transformer)