def test_transformer_transform_with_client_error(input_fn, predict_fn, output_fn):
    """A ClientError raised during transform() reaches the caller.

    The handlers are supplied by the caller (presumably parametrized so that one
    of them raises -- confirm). Construction happens inside the ``raises``
    block as well, so a ClientError from the constructor would also satisfy it.
    """
    with pytest.raises(_errors.ClientError) as exc_info:
        transformer = _transformer.Transformer(
            model_fn=MagicMock(),
            input_fn=input_fn,
            predict_fn=predict_fn,
            output_fn=output_fn,
        )
        transformer.transform()

    assert exc_info.value.args[0] == error_from_fn
def test_transformer_too_many_custom_methods():
    """Passing transform_fn together with input_fn, predict_fn and output_fn raises ValueError."""
    with pytest.raises(ValueError) as e:
        _transformer.Transformer(
            input_fn=MagicMock(),
            predict_fn=MagicMock(),
            output_fn=MagicMock(),
            transform_fn=MagicMock(),
        )

    # Check the exception itself: str() of the ExceptionInfo wrapper is a
    # (possibly truncated) repr, not the exception message.
    assert (
        'Cannot use transform_fn implementation with input_fn, predict_fn, and/or output_fn'
        in str(e.value)
    )
def test_transformer_transform_with_unsupported_content_type():
    """An unsupported request content type produces an UNSUPPORTED_MEDIA_TYPE JSON error."""
    # NOTE(review): a function with this same name is defined again later in
    # this module, so this earlier copy is shadowed and pytest never collects
    # it. Rename or remove one of the two copies.
    bad_request = test.request(data=None, content_type='fake/content-type')

    def request_factory():
        return bad_request

    with patch('sagemaker_containers._worker.Request', request_factory):
        transformer = _transformer.Transformer()
        response = transformer.transform()

    assert response.status_code == http_client.UNSUPPORTED_MEDIA_TYPE

    raw_body = response.response[0]
    body = json.loads(raw_body.decode('utf-8'))
    assert body['error'] == 'UnsupportedFormatError'
    assert bad_request.content_type in body['error-message']
def test_transformer_transform_with_unsupported_content_type():
    """A request whose ContentType header is unsupported gets an UNSUPPORTED_MEDIA_TYPE error."""
    bad_request = test.request(
        data=None,
        headers={"ContentType": "fake/content-type"},
    )

    def request_factory():
        return bad_request

    with patch("sagemaker_containers._worker.Request", request_factory):
        transformer = _transformer.Transformer()
        response = transformer.transform()

    assert response.status_code == http_client.UNSUPPORTED_MEDIA_TYPE

    raw_body = response.response[0]
    body = json.loads(raw_body.decode("utf-8"))
    assert body["error"] == "UnsupportedFormatError"
    assert bad_request.content_type in body["error-message"]
def test_transformer_with_custom_transform_fn():
    """A custom transform_fn is invoked with the model returned by model_fn."""
    # NOTE(review): a function with this same name is defined again later in
    # this module, so this earlier copy is shadowed and never collected.
    loaded_model = MagicMock()
    custom_transform_fn = MagicMock()

    def model_fn(model_dir):
        return loaded_model

    transformer = _transformer.Transformer(
        model_fn=model_fn, transform_fn=custom_transform_fn
    )
    transformer.initialize()
    transformer.transform()

    # "13" / CSV / ANY presumably describe a request stubbed elsewhere in this
    # module (e.g. by a fixture) -- confirm.
    custom_transform_fn.assert_called_with(
        loaded_model, "13", _content_types.CSV, _content_types.ANY
    )
def test_transformer_transform_with_unsupported_accept_type():
    """An unsupported Accept header produces a NOT_ACCEPTABLE UnsupportedFormatError response."""
    # NOTE(review): a function with this same name is defined again later in
    # this module, so this earlier copy is shadowed and never collected.
    def empty_fn(*args):
        return None

    bad_request = test.request(data=None, headers={"Accept": "fake/content-type"})
    with patch("sagemaker_containers._worker.Request", lambda: bad_request):
        transformer = _transformer.Transformer(
            model_fn=empty_fn,
            input_fn=empty_fn,
            predict_fn=empty_fn,
        )
        response = transformer.transform()

    assert response.status_code == http_client.NOT_ACCEPTABLE

    payload = response.response[0].decode("utf-8")
    body = json.loads(payload)
    assert body["error"] == "UnsupportedFormatError"
    assert bad_request.accept in body["error-message"]
def test_transformer_transform():
    """transform() chains input_fn -> predict_fn -> output_fn and returns output_fn's result."""
    # NOTE(review): a function with this same name is defined again later in
    # this module, so this earlier copy is shadowed and never collected.
    model_fn, input_fn, predict_fn = (MagicMock(), MagicMock(), MagicMock())
    output_fn = MagicMock(return_value="response")

    transform = _transformer.Transformer(
        model_fn=model_fn, input_fn=input_fn, predict_fn=predict_fn, output_fn=output_fn
    )

    transform.initialize()
    assert transform.transform() == "response"

    # Compare against each mock's return_value instead of calling the mock:
    # calling it would record an extra call and make these assertions depend
    # on the order in which they are written.
    input_fn.assert_called_with("42", _content_types.JSON)
    predict_fn.assert_called_with(input_fn.return_value, model_fn.return_value)
    output_fn.assert_called_with(predict_fn.return_value, _content_types.NPY)
def test_transformer_with_custom_transform_fn():
    """transform_fn receives the model plus the request's content, content type and accept."""
    model = MagicMock()
    transform_fn = MagicMock()

    def model_fn(model_dir):
        return model

    transformer = _transformer.Transformer(model_fn=model_fn, transform_fn=transform_fn)
    transformer.initialize()
    transformer.transform()

    # ``request`` is a module-level object, presumably the same request the
    # Transformer reads during transform() -- confirm.
    expected_args = (model, request.content, request.content_type, request.accept)
    transform_fn.assert_called_with(*expected_args)
def test_transformer_transform(response):
    """transform() chains input_fn -> predict_fn -> output_fn and returns output_fn's result.

    ``response`` is supplied by pytest (fixture or parametrization defined
    outside this view).
    """
    model_fn, input_fn, predict_fn = (MagicMock(), MagicMock(), MagicMock())
    output_fn = MagicMock(return_value=response)

    transform = _transformer.Transformer(model_fn=model_fn,
                                         input_fn=input_fn,
                                         predict_fn=predict_fn,
                                         output_fn=output_fn)

    transform.initialize()
    assert transform.transform() == response

    # Compare against each mock's return_value instead of calling the mock:
    # calling it would record an extra call and make these assertions depend
    # on the order in which they are written.
    input_fn.assert_called_with(request.content, request.content_type)
    predict_fn.assert_called_with(input_fn.return_value, model_fn.return_value)
    output_fn.assert_called_with(predict_fn.return_value, request.accept)
def test_transformer_transform_with_unsupported_accept_type():
    """A request with an unsupported accept value gets a NOT_ACCEPTABLE error response."""
    def empty_fn(*args):
        return None

    bad_request = test.request(data=None, accept='fake/content_type')

    def request_factory():
        return bad_request

    with patch('sagemaker_containers._worker.Request', request_factory):
        transformer = _transformer.Transformer(
            model_fn=empty_fn, input_fn=empty_fn, predict_fn=empty_fn)
        response = transformer.transform()

    assert response.status_code == http_client.NOT_ACCEPTABLE

    payload = response.response[0].decode('utf-8')
    body = json.loads(payload)
    assert body['error'] == 'UnsupportedFormatError'
    assert bad_request.accept in body['error-message']
def test_transformer_transform_backwards_compatibility():
    """An output_fn returning a 2-tuple still yields an OK response.

    The tuple here is ``(0, _content_types.ANY)``; presumably this is the legacy
    ``(body, content_type)`` return form -- confirm.
    """
    # NOTE(review): a function with this same name is defined again later in
    # this module, so this earlier copy is shadowed and never collected.
    model_fn, input_fn, predict_fn, output_fn = (
        MagicMock(), MagicMock(), MagicMock(),
        MagicMock(return_value=(0, _content_types.ANY)))

    transform = _transformer.Transformer(model_fn=model_fn,
                                         input_fn=input_fn,
                                         predict_fn=predict_fn,
                                         output_fn=output_fn)

    transform.initialize()

    assert transform.transform().status_code == http_client.OK

    # Compare against return_value rather than calling the mocks, which would
    # record extra calls and make the assertions order-dependent.
    input_fn.assert_called_with('13', _content_types.CSV)
    predict_fn.assert_called_with(input_fn.return_value, model_fn.return_value)
    output_fn.assert_called_with(predict_fn.return_value, _content_types.ANY)
def test_transformer_transform_backwards_compatibility():
    """An output_fn returning a 2-tuple ``(0, 1)`` still yields an OK response.

    Presumably the tuple mimics a legacy ``(body, content_type)`` return -- confirm.
    """
    model_fn, input_fn, predict_fn, output_fn = (MagicMock(), MagicMock(),
                                                 MagicMock(),
                                                 MagicMock(return_value=(0,
                                                                         1)))

    transform = _transformer.Transformer(model_fn=model_fn,
                                         input_fn=input_fn,
                                         predict_fn=predict_fn,
                                         output_fn=output_fn)

    transform.initialize()

    assert transform.transform().status_code == _status_codes.OK

    # Compare against return_value rather than calling the mocks, which would
    # record extra calls and make the assertions order-dependent.
    input_fn.assert_called_with(request.content, request.content_type)
    predict_fn.assert_called_with(input_fn.return_value, model_fn.return_value)
    output_fn.assert_called_with(predict_fn.return_value, request.accept)
def test_transformer_initialize_with_client_error():
    """A ClientError raised from model_fn during initialize() surfaces to the caller."""
    with pytest.raises(_errors.ClientError) as exc_info:
        transformer = _transformer.Transformer(model_fn=fn_with_error)
        transformer.initialize()

    # error_from_fn is presumably the message fn_with_error raises with -- confirm.
    assert exc_info.value.args[0] == error_from_fn
def test_transformer_initialize_with_default_model_fn():
    """initialize() on a Transformer built without a model_fn raises NotImplementedError."""
    with pytest.raises(NotImplementedError):
        transformer = _transformer.Transformer()
        transformer.initialize()
def test_initialize():
    """initialize() calls model_fn with ``_env.model_dir``."""
    loader = MagicMock()
    transformer = _transformer.Transformer(model_fn=loader)

    transformer.initialize()

    loader.assert_called_with(_env.model_dir)
def test_transformer_with_default_predict_fn():
    """transform() on a Transformer built with no handlers raises NotImplementedError.

    The test name suggests the default predict_fn is what raises -- confirm.
    """
    with pytest.raises(NotImplementedError):
        transformer = _transformer.Transformer()
        transformer.transform()