def test_model(sagemaker_session):
    """Deploying a ChainerModel must hand back a ChainerPredictor."""
    chainer_model = ChainerModel(
        "s3://some/data.tar.gz",
        role=ROLE,
        entry_point=SCRIPT_PATH,
        sagemaker_session=sagemaker_session,
    )

    deployed = chainer_model.deploy(1, GPU)

    assert isinstance(deployed, ChainerPredictor)
def test_serving_calls_model_fn_once(docker_image, sagemaker_local_session):
    """Check that repeated predictions against a local endpoint all succeed.

    The model is served with the ``call_model_fn_once.py`` entry point and two
    model server workers. Three requests are sent and each must return
    ``b'output'``; the endpoint is expected to answer with a 500 error if
    ``model_fn`` is invoked while a request is being handled.

    The endpoint is deleted afterwards, but only if the deployment actually
    produced a predictor.
    """
    script_path = os.path.join(resources_path, 'call_model_fn_once.py')
    model_path = 'file://{}'.format(os.path.join(resources_path, 'model.tar.gz'))

    model = ChainerModel(
        model_path,
        'unused/dummy-role',
        script_path,
        image=docker_image,
        model_server_workers=2,
        sagemaker_session=sagemaker_local_session,
    )

    with test_utils.local_mode_lock():
        # Bind the name before entering the try block so that a failing
        # deploy() surfaces its own exception instead of being replaced by an
        # UnboundLocalError raised from the cleanup code.
        predictor = None
        try:
            predictor = model.deploy(1, 'local')
            predictor.accept = None
            predictor.deserializer = BytesDeserializer()

            # call enough times to ensure multiple requests to a worker
            for _ in range(3):
                # will return 500 error if model_fn called during request handling
                response = predictor.predict(b'input')
                assert response == b'output'
        finally:
            if predictor is not None:
                predictor.delete_endpoint()
def test_model_prepare_container_def_accelerator_error(sagemaker_session):
    """prepare_container_def must raise ValueError when given an accelerator type."""
    chainer_model = ChainerModel(
        MODEL_DATA,
        role=ROLE,
        entry_point=SCRIPT_PATH,
        sagemaker_session=sagemaker_session,
    )

    expect_value_error = pytest.raises(ValueError)
    with expect_value_error:
        chainer_model.prepare_container_def(
            INSTANCE_TYPE,
            accelerator_type=ACCELERATOR_TYPE,
        )
def test_model_prepare_container_def_no_instance_type_or_image():
    """prepare_container_def() with no arguments must raise a descriptive ValueError.

    The model is built without an image, and prepare_container_def is called
    without an instance type, so the call is expected to fail with a message
    asking for one of the two.
    """
    model = ChainerModel(MODEL_DATA, role=ROLE, entry_point=SCRIPT_PATH)

    with pytest.raises(ValueError) as e:
        model.prepare_container_def()

    expected_msg = "Must supply either an instance type (for choosing CPU vs GPU) or an image URI."
    # Inspect the raised exception itself (e.value) rather than the
    # ExceptionInfo wrapper, whose string form depends on the pytest version.
    assert expected_msg in str(e.value)
def test_model_empty_framework_version(warning, sagemaker_session):
    """A ChainerModel built with framework_version=None uses the default version.

    Also checks that the ``warning`` fixture was called with the default
    Chainer version and the latest version.
    """
    init_kwargs = {
        "role": ROLE,
        "entry_point": SCRIPT_PATH,
        "sagemaker_session": sagemaker_session,
        "framework_version": None,
    }
    chainer_model = ChainerModel(MODEL_DATA, **init_kwargs)

    assert chainer_model.framework_version == defaults.CHAINER_VERSION

    warning.assert_called_with(defaults.CHAINER_VERSION, defaults.LATEST_VERSION)
def test_model_py2_warning(warning, sagemaker_session):
    """A ChainerModel built with py_version="py2" keeps it and triggers the warning.

    The ``warning`` fixture must have been called with the model's framework
    name and the latest py2-compatible version.
    """
    init_kwargs = {
        "role": ROLE,
        "entry_point": SCRIPT_PATH,
        "sagemaker_session": sagemaker_session,
        "py_version": "py2",
    }
    chainer_model = ChainerModel(MODEL_DATA, **init_kwargs)

    assert chainer_model.py_version == "py2"

    warning.assert_called_with(
        chainer_model.__framework_name__,
        defaults.LATEST_PY2_VERSION,
    )
def test_model_custom_serialization(sagemaker_session, chainer_version, chainer_py_version):
    """Serializer/deserializer objects passed to deploy() end up on the predictor."""
    serializer_stub, deserializer_stub = Mock(), Mock()

    chainer_model = ChainerModel(
        "s3://some/data.tar.gz",
        role=ROLE,
        entry_point=SCRIPT_PATH,
        sagemaker_session=sagemaker_session,
        framework_version=chainer_version,
        py_version=chainer_py_version,
    )

    deploy_kwargs = {
        "serializer": serializer_stub,
        "deserializer": deserializer_stub,
    }
    deployed = chainer_model.deploy(1, CPU, **deploy_kwargs)

    assert isinstance(deployed, ChainerPredictor)
    # Identity, not equality: the very same objects must be attached.
    assert deployed.serializer is serializer_stub
    assert deployed.deserializer is deserializer_stub