コード例 #1
0
def test_adapt_to_ts_format(path_exists, make_dir, subprocess_check_call,
                            set_python_path):
    """``_adapt_to_ts_format`` prepares the export dir and runs the archiver.

    Checks that the TorchServe model directory is probed and created, that
    ``torch-model-archiver`` is invoked with the expected argument list, and
    that the Python path is set afterwards.
    """
    handler = Mock()
    ts_model_dir = torchserve.DEFAULT_TS_MODEL_DIRECTORY

    torchserve._adapt_to_ts_format(handler)

    path_exists.assert_called_once_with(ts_model_dir)
    make_dir.assert_called_once_with(ts_model_dir)

    serialized_model = os.path.join(
        environment.model_dir, torchserve.DEFAULT_TS_MODEL_SERIALIZED_FILE)
    user_script = os.path.join(
        environment.model_dir, torchserve.DEFAULT_TS_CODE_DIR,
        environment.Environment().module_name + ".py")

    expected_cmd = [
        "torch-model-archiver",
        "--model-name", torchserve.DEFAULT_TS_MODEL_NAME,
        "--handler", handler,
        "--serialized-file", serialized_model,
        "--export-path", ts_model_dir,
        "--extra-files", user_script,
        "--version", "1",
    ]

    subprocess_check_call.assert_called_once_with(expected_cmd)
    set_python_path.assert_called_once_with()
コード例 #2
0
def _generate_mms_config_properties():
    """Return the model server configuration properties as a single string.

    The result is the contents of a default configuration file followed by
    one ``key=value`` line per environment-derived setting whose value is
    truthy. The default file is ``MME_MMS_CONFIG_FILE`` when
    ``ENABLE_MULTI_MODEL`` is truthy and ``DEFAULT_MMS_CONFIG_FILE``
    otherwise.
    """
    env = environment.Environment()

    user_defined_configuration = {
        "default_response_timeout": env.model_server_timeout,
        "default_workers_per_model": env.model_server_workers,
        "inference_address":
        "http://0.0.0.0:{}".format(env.inference_http_port),
        "management_address":
        "http://0.0.0.0:{}".format(env.management_http_port),
        "vmargs": "-XX:-UseContainerSupport",
    }

    # Falsy values (presumably unset settings) are omitted entirely rather
    # than written out as an empty ``key=`` line.
    custom_configuration = "".join(
        "{}={}\n".format(key, value)
        for key, value in user_defined_configuration.items()
        if value)

    config_file = (MME_MMS_CONFIG_FILE
                   if ENABLE_MULTI_MODEL else DEFAULT_MMS_CONFIG_FILE)
    default_configuration = utils.read_file(config_file)

    return default_configuration + custom_configuration
コード例 #3
0
def _generate_mms_config_properties():
    """Return the model server configuration properties as a single string.

    The result is the contents of a default configuration file followed by
    one ``key=value`` line per environment-derived setting whose value is
    truthy. The default file is ``MME_MMS_CONFIG_FILE`` when
    ``ENABLE_MULTI_MODEL`` is truthy and ``DEFAULT_MMS_CONFIG_FILE``
    otherwise.
    """
    env = environment.Environment()

    user_defined_configuration = {
        'default_response_timeout': env.model_server_timeout,
        'default_workers_per_model': env.model_server_workers,
        'inference_address':
        'http://0.0.0.0:{}'.format(env.inference_http_port),
        'management_address':
        'http://0.0.0.0:{}'.format(env.management_http_port)
    }

    # Falsy values (presumably unset settings) are omitted entirely rather
    # than written out as an empty ``key=`` line.
    custom_configuration = ''.join(
        '{}={}\n'.format(key, value)
        for key, value in user_defined_configuration.items()
        if value)

    config_file = (MME_MMS_CONFIG_FILE
                   if ENABLE_MULTI_MODEL else DEFAULT_MMS_CONFIG_FILE)
    default_configuration = utils.read_file(config_file)

    return default_configuration + custom_configuration
コード例 #4
0
def test_env_module_name(sagemaker_program):
    """``Environment().module_name`` is derived from the user-program env var.

    Every ``sagemaker_program`` fixture value is expected to yield the module
    name ``"program"``.
    """
    env_key = parameters.USER_PROGRAM_ENV
    previous_value = os.environ.get(env_key)
    os.environ[env_key] = sagemaker_program
    try:
        module_name = environment.Environment().module_name
    finally:
        # Always restore the environment, even if Environment() raises, so
        # this variable does not leak into other tests; put back any value
        # that was present before the test instead of deleting it.
        if previous_value is None:
            del os.environ[env_key]
        else:
            os.environ[env_key] = previous_value

    assert module_name == "program"
コード例 #5
0
def _adapt_to_ts_format(handler_service):
    """Package the model for TorchServe with ``torch-model-archiver``.

    Creates ``DEFAULT_TS_MODEL_DIRECTORY`` if it does not exist. It then runs
    the archiver with the given handler, the serialized model file and the
    user module (from ``DEFAULT_TS_CODE_DIR``) as an extra file, and finally
    calls ``_set_python_path()``. A failing archiver propagates the error
    raised by ``subprocess.check_call``.
    """
    export_dir = DEFAULT_TS_MODEL_DIRECTORY
    if not os.path.exists(export_dir):
        os.makedirs(export_dir)

    model_root = environment.model_dir
    serialized_file = os.path.join(model_root,
                                   DEFAULT_TS_MODEL_SERIALIZED_FILE)
    user_module_file = os.path.join(
        model_root, DEFAULT_TS_CODE_DIR,
        environment.Environment().module_name + ".py")

    archiver_options = (
        ("--model-name", DEFAULT_TS_MODEL_NAME),
        ("--handler", handler_service),
        ("--serialized-file", serialized_file),
        ("--export-path", export_dir),
        ("--extra-files", user_module_file),
        ("--version", "1"),
    )
    model_archiver_cmd = ["torch-model-archiver"]
    for flag, flag_value in archiver_options:
        model_archiver_cmd.extend((flag, flag_value))

    logger.info(model_archiver_cmd)
    subprocess.check_call(model_archiver_cmd)

    _set_python_path()
コード例 #6
0
    def validate_and_initialize(self, model_dir):  # type: (str) -> None
        """Validate the user module against the SageMaker inference contract.

        Then load the model by calling the user's ``model_fn`` with
        ``model_dir`` so predictions can be handled, and set
        ``self._initialized`` to True. Unlike the guarded variants, this runs
        unconditionally on every call.

        Args:
            model_dir: directory passed straight to ``model_fn``.

        NOTE: This still uses environment values from the legacy
        sagemaker-containers. This should be removed.
        """
        self._environment = environment.Environment()
        self._validate_user_module_and_set_functions()
        self._model = self._model_fn(model_dir)
        self._initialized = True
コード例 #7
0
    def validate_and_initialize(self):  # type: () -> None
        """Validate the user module and load the model, once.

        On the first call this validates the user module against the
        SageMaker inference contract and loads the model from
        ``environment.model_dir`` via ``model_fn``. Once
        ``self._initialized`` is truthy, further calls do nothing.
        """
        if self._initialized:
            return

        self._environment = environment.Environment()
        self._validate_user_module_and_set_functions()
        self._model = self._model_fn(environment.model_dir)
        self._initialized = True
コード例 #8
0
def test_env():
    """Environment paths and settings match the values configured for tests.

    The expected values are presumably injected by fixtures or environment
    variables defined outside this function; confirm against conftest.
    """
    env = environment.Environment()

    assert environment.base_dir.endswith('/opt/ml')
    assert environment.model_dir.endswith('/opt/ml/model')
    # Leading slash added for consistency with the assertions above; without
    # it the check would also accept e.g. '/xopt/ml/model/code'.
    assert environment.code_dir.endswith('/opt/ml/model/code')
    assert env.module_name == 'main'
    assert env.model_server_timeout == 20
    assert env.model_server_workers == '8'
    assert env.default_accept == 'text/html'
    assert env.http_port == '1738'
    assert env.safe_port_range == '1111-2222'
コード例 #9
0
def test_env():
    """Environment paths and settings match the values configured for tests.

    The expected values are presumably injected by fixtures or environment
    variables defined outside this function; confirm against conftest.
    """
    env = environment.Environment()

    assert environment.base_dir.endswith("/opt/ml")
    assert environment.model_dir.endswith("/opt/ml/model")
    # Leading slash added for consistency with the assertions above; without
    # it the check would also accept e.g. "/xopt/ml/model/code".
    assert environment.code_dir.endswith("/opt/ml/model/code")
    assert env.module_name == "main"
    assert env.model_server_timeout == 20
    assert env.model_server_workers == "8"
    assert env.default_accept == "text/html"
    assert env.inference_http_port == "1738"
    assert env.management_http_port == "1738"
    assert env.safe_port_range == "1111-2222"
コード例 #10
0
    def validate_and_initialize(self,
                                model_dir=environment.model_dir,
                                gpu_id=None):
        """Validate the user module and load the model, once.

        On the first call this validates the user module against the
        SageMaker inference contract and loads the model via the user's
        ``model_fn``. Once ``self._initialized`` is truthy, further calls do
        nothing.

        Args:
            model_dir: directory passed to ``model_fn``. The default is the
                value of ``environment.model_dir`` at class-definition time.
            gpu_id: forwarded to ``model_fn`` as the ``gpu_id`` keyword, but
                only when ``model_fn``'s argument spec lists exactly two
                positional parameters; otherwise it is ignored.
        """
        if not self._initialized:
            self._environment = environment.Environment()
            self._validate_user_module_and_set_functions()
            # ``inspect.getargspec`` was removed in Python 3.11 and raises
            # ValueError for functions with annotations or keyword-only
            # arguments. Prefer ``getfullargspec`` (same ``.args`` list) and
            # only fall back to ``getargspec`` where it is unavailable.
            get_argspec = (getattr(inspect, "getfullargspec", None)
                           or inspect.getargspec)
            num_args_model_fn = len(get_argspec(self._model_fn).args)
            if num_args_model_fn == 2:
                self._model = self._model_fn(model_dir, gpu_id=gpu_id)
            else:
                self._model = self._model_fn(model_dir)
            self._initialized = True
コード例 #11
0
    def _user_module_transformer():
        """Choose a transformer that matches the user module and its model.

        If the user module defines ``transform_fn``, return a ``Transformer``
        that uses the default MXNet inference handler. Otherwise, load the
        model with the user's ``model_fn``, or the default handler's
        ``default_model_fn`` when ``model_fn`` is absent. Then pick a
        transformer based on the loaded model's type.

        Raises:
            ValueError: if the model is neither an MXNet ``BaseModule`` nor a
                Gluon ``Block``.
        """
        module_name = environment.Environment().module_name
        user_module = importlib.import_module(module_name)

        if hasattr(user_module, 'transform_fn'):
            mxnet_handler = DefaultMXNetInferenceHandler()
            return Transformer(default_inference_handler=mxnet_handler)

        # The fallback handler is built unconditionally, matching the eager
        # evaluation of a getattr() default argument.
        fallback_model_fn = DefaultMXNetInferenceHandler().default_model_fn
        load_model = getattr(user_module, 'model_fn', fallback_model_fn)
        model = load_model(environment.model_dir)

        if isinstance(model, mx.module.BaseModule):
            return MXNetModuleTransformer()
        if isinstance(model, mx.gluon.block.Block):
            gluon_handler = DefaultGluonBlockInferenceHandler()
            return Transformer(default_inference_handler=gluon_handler)

        raise ValueError('Unsupported model type: {}'.format(
            model.__class__.__name__))
コード例 #12
0
    # Training hyperparameters.
    parser.add_argument('--batch_size',
                        type=int,
                        default=4,
                        metavar='BS',
                        help='batch size (default: 4)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.001,
                        metavar='LR',
                        help='initial learning rate (default: 0.001)')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.9,
                        metavar='M',
                        help='momentum (default: 0.9)')
    parser.add_argument('--dist_backend',
                        type=str,
                        default='gloo',
                        help='distributed backend (default: gloo)')

    # Cluster and data settings; defaults come from the container environment.
    env = environment.Environment()
    # argparse applies ``type`` to each command-line string, so the previous
    # ``type=list`` turned ``--hosts algo-1`` into ['a', 'l', 'g', 'o', ...].
    # Accept one or more host names instead. When the flag is omitted,
    # ``env.hosts`` (presumably already a list) is used as-is.
    parser.add_argument('--hosts', type=str, nargs='+', default=env.hosts)
    parser.add_argument('--current-host', type=str, default=env.current_host)
    parser.add_argument('--model-dir', type=str, default=env.model_dir)
    parser.add_argument('--data-dir',
                        type=str,
                        default=env.channel_input_dirs.get('training'))
    parser.add_argument('--num-gpus', type=int, default=env.num_gpus)

    _train(parser.parse_args())