Ejemplo n.º 1
0
def test_learn_with_id(client, app, regression, lin_reg):
    """Learning from a stored prediction consumes the stored sample.

    A prediction made with an ``id`` is kept in the shelf under ``'#<id>'``
    with its features, model name and prediction. Posting the ground truth
    for that same ``id`` to ``/api/learn`` must return 201 and remove the
    stored entry.
    """

    # Make a prediction tagged with an id so that it gets stored
    r = client.post('/api/predict',
                    data=json.dumps({
                        'id': 42,
                        'features': {
                            'x': 1
                        }
                    }),
                    content_type='application/json')
    # Fail here with the actual status rather than later on a missing shelf key
    assert r.status_code == 201

    # Check the sample has been stored
    with app.app_context():
        shelf = storage.get_db()
        sample = shelf['#42']
        assert sorted(sample) == ['features', 'model', 'prediction']
        assert sample['features'] == {'x': 1}

    # Provide the ground truth for the stored prediction
    r = client.post('/api/learn',
                    data=json.dumps({
                        'id': 42,
                        'ground_truth': True
                    }),
                    content_type='application/json')
    assert r.status_code == 201

    # Check the sample has now been removed
    with app.app_context():
        shelf = storage.get_db()
        assert '#42' not in shelf
Ejemplo n.º 2
0
def test_delete_model(client, app, regression):
    """Deleting a model through the API removes it from storage."""
    url = '/api/model/healthy-banana'
    key = 'models/healthy-banana'

    # Upload a fresh, unfitted regressor under that name and confirm it landed
    client.post(url, data=pickle.dumps(linear_model.LinearRegression()))
    with app.app_context():
        assert key in storage.get_db()

    # Remove it again; the shelf must no longer hold the entry
    client.delete(url)
    with app.app_context():
        assert key not in storage.get_db()
Ejemplo n.º 3
0
def test_model_upload(client, app, regression):
    """An uploaded model is stored and can be downloaded again.

    The model carries a ``probe`` UUID so each check can confirm it is
    looking at this exact model. The round trip is verified three ways:
    directly in the shelf, through ``GET /api/model/<name>``, and through
    ``GET /api/model`` without a name (the default model).
    """
    model = linear_model.LinearRegression()
    probe = uuid.uuid4()
    model.probe = probe

    # Upload under an explicit name; the API echoes the name back
    response = client.post('/api/model/healthy-banana', data=pickle.dumps(model))
    assert response.status_code == 201
    assert response.json == {'name': 'healthy-banana'}

    # The shelf must hold the same model under the name it was given
    with app.app_context():
        stored = storage.get_db()['models/healthy-banana']
        assert isinstance(stored, linear_model.LinearRegression)
        assert stored.probe == probe

    # Download it by its name, then via the default-model endpoint
    for url in ('/api/model/healthy-banana', '/api/model'):
        downloaded = pickle.loads(client.get(url).get_data())
        assert isinstance(downloaded, linear_model.LinearRegression)
        assert downloaded.probe == probe
Ejemplo n.º 4
0
def test_predict_with_id(client, app, regression, lin_reg):
    """A prediction requested with an ``id`` is stored under ``'#<id>'``."""
    payload = {'features': {}, 'id': '90210'}
    response = client.post(
        '/api/predict',
        data=json.dumps(payload),
        content_type='application/json',
    )
    assert response.status_code == 201
    assert response.json == {'model': 'lin-reg', 'prediction': 0}

    # The prediction is kept around, keyed by its id, for later learning
    with app.app_context():
        assert '#90210' in storage.get_db()
Ejemplo n.º 5
0
def test_init(client, app):
    """POSTing a flavor to ``/api/init`` persists it; GET reports the setup."""
    response = client.post(
        '/api/init',
        data=json.dumps({'flavor': 'regression'}),
        content_type='application/json',
    )
    assert response.status_code == 201

    # The chosen flavor is persisted in storage
    with app.app_context():
        flavor = storage.get_db()['flavor']
        assert flavor.name == 'regression'

    # GET describes the storage backend, the flavor and the library version
    info = client.get('/api/init').json
    assert info == {
        'storage': app.config['STORAGE_BACKEND'],
        'flavor': 'regression',
        'creme_version': creme.__version__,
    }
Ejemplo n.º 6
0
def test_add_model(app):
    """The ``add_model`` CLI command stores a pickled model in the shelf.

    The model is pickled to a temporary file, registered under the name
    ``banana`` via the CLI, and then read back from storage. A ``probe``
    attribute set to a fresh UUID confirms the stored object is the one
    that was pickled.
    """
    import tempfile

    runner = app.test_cli_runner()

    # Instantiate a model and tag it with a unique probe
    model = linear_model.LinearRegression()
    probe = uuid.uuid4()
    model.probe = probe

    # Pickle into a private temporary directory: nothing is written to the
    # working directory, and the file is removed even if an assertion fails
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, 'tmp.pkl')
        with open(path, 'wb') as f:
            pickle.dump(model, f)

        # Add the model to the shelf through the CLI
        result = runner.invoke(cli.add_model, [path, '--name', 'banana'])
        # Show the CLI output on failure to make the cause visible
        assert result.exit_code == 0, result.output

    # Check that the model has been added to the shelf
    with app.app_context():
        db = storage.get_db()
        stored = db['models/banana']
        assert isinstance(stored, linear_model.LinearRegression)
        assert stored.probe == probe