コード例 #1
0
ファイル: test_model_review.py プロジェクト: chrinide/a2ml
def test_score_actuals_for_candidate_prediction():
    """Actuals keyed by primary-model prediction ids are stored for the candidate.

    The submitted actuals reference the *primary* model's prediction ids. The
    test expects ``add_actuals`` (called with both the candidate and primary
    prediction groups and ``calc_score=True``) to:

    * return a score dict with ``accuracy == 1.0``, and
    * write exactly one dated actuals file under the candidate model's
      ``predictions`` directory, whose rows carry the candidate's prediction
      ids, the candidate ``prediction_group_id`` and the candidate's
      predicted ``species``.
    """
    # Candidate prediction data (fixture):
    # { 'prediction_id':'bef9be07-5534-434e-ab7c-c379d8fcfe77', 'species':'versicolor' },
    # { 'prediction_id':'f61b1bbc-6f7b-4e7e-9a3b-6acb6e1462cd', 'species':'virginica' }
    model_path = 'tests/fixtures/test_score_actuals/pr_can/candidate'
    prediction_group_id = '272B088D17A7490'

    # Primary prediction data (fixture):
    # { 'prediction_id':'09aaa96b-5d9c-4c45-ab04-726da868624b', 'species':'virginica' },
    # { 'prediction_id':'5e5ad22b-6789-47c6-9a4d-a3a998065127', 'species':'virginica' }
    primary_model_path = 'tests/fixtures/test_score_actuals/pr_can/primary'
    primary_prediction_group_id = 'A4FD5B64FEE5434'

    actuals_pattern = model_path + '/predictions/*_actuals.feather.zstd'

    # Remove actuals left by earlier runs so the file-count check below only
    # sees the file written by this run.
    for actuals_path in glob.glob(actuals_pattern):
        os.remove(actuals_path)

    # Actuals are keyed by the *primary* model's prediction ids.
    actuals = [{
        'prediction_id': '09aaa96b-5d9c-4c45-ab04-726da868624b',
        'actual': 'versicolor'
    }, {
        'prediction_id': '5e5ad22b-6789-47c6-9a4d-a3a998065127',
        'actual': 'virginica'
    }]

    res = ModelReview({
        'model_path': model_path
    }).add_actuals(actual_records=actuals,
                   prediction_group_id=prediction_group_id,
                   primary_prediction_group_id=primary_prediction_group_id,
                   primary_model_path=primary_model_path,
                   calc_score=True)

    assert isinstance(res, dict)
    assert res['accuracy'] == 1.0

    actual_files = glob.glob(actuals_pattern)
    assert len(actual_files) == 1
    actual_file = actual_files[0]
    # No actual_date is passed, so the file name is expected to carry today's date.
    assert str(datetime.date.today()) in actual_file

    stored_actuals = DataFrame({})
    stored_actuals.loadFromFeatherFile(actual_file)
    assert 'prediction_group_id' in stored_actuals.columns

    # Sort by prediction_id so the row order checked below is deterministic.
    stored_actuals = json.loads(
        stored_actuals.df.sort_values(by=['prediction_id']).to_json(
            orient='records'))

    # Check the row count first, so a missing row fails with a clear
    # assertion rather than an IndexError below.
    assert len(stored_actuals) == len(actuals)

    assert stored_actuals[0][
        'prediction_id'] == 'bef9be07-5534-434e-ab7c-c379d8fcfe77'
    assert stored_actuals[0]['prediction_group_id'] == prediction_group_id
    assert stored_actuals[0]['species'] == 'versicolor'

    assert stored_actuals[1][
        'prediction_id'] == 'f61b1bbc-6f7b-4e7e-9a3b-6acb6e1462cd'
    assert stored_actuals[1]['prediction_group_id'] == prediction_group_id
    assert stored_actuals[1]['species'] == 'virginica'
コード例 #2
0
ファイル: test_model_review.py プロジェクト: chrinide/a2ml
def test_score_actuals_with_not_full_actuals():
    """Actuals for a subset of prediction ids are stored in one dated file.

    Sends actuals (dated yesterday) for three prediction ids that belong to
    two different prediction groups. The test then checks that:

    * exactly one actuals file is written, named with that date;
    * it holds one row per submitted actual;
    * each row carries columns that were not in the submitted records
      (``prediction_group_id``, ``feature1``);
    * the submitted actual value appears under the ``income`` column.
    """
    model_path = 'tests/fixtures/test_score_actuals'
    actuals_pattern = model_path + '/predictions/*_actuals.feather.zstd'

    # Remove actuals left by earlier runs so the single file found below is
    # the one written by this run.
    for actuals_path in glob.glob(actuals_pattern):
        os.remove(actuals_path)

    actuals = [
        {
            'prediction_id': '5c93079c-00c9-497a-8967-53fa0dd02054',
            'actual': False
        },
        {
            'prediction_id': 'b1bf9ebf-0277-4771-9bc5-236690a21194',
            'actual': False
        },
        {
            'prediction_id': 'f61b1bbc-6f7b-4e7e-9a3b-6acb6e1462cd',
            'actual': True
        },
    ]

    actual_date = datetime.date.today() - datetime.timedelta(days=1)

    ModelReview({
        'model_path': model_path
    }).add_actuals(actuals_path=None,
                   actual_records=actuals,
                   actual_date=actual_date)

    # glob() returns matches in arbitrary order. Require exactly one file so
    # the assertions below cannot silently inspect an unrelated file.
    actual_files = glob.glob(actuals_pattern)
    assert len(actual_files) == 1
    actual_file = actual_files[0]
    assert str(actual_date) in actual_file

    stored_actuals = DataFrame({})
    stored_actuals.loadFromFeatherFile(actual_file)
    assert 'prediction_group_id' in stored_actuals.columns

    # Sort by prediction_id so the row order checked below is deterministic.
    stored_actuals = json.loads(
        stored_actuals.df.sort_values(by=['prediction_id']).to_json(
            orient='records'))

    assert len(stored_actuals) == len(actuals)

    assert stored_actuals[0][
        'prediction_id'] == '5c93079c-00c9-497a-8967-53fa0dd02054'
    assert stored_actuals[0][
        'prediction_group_id'] == '2ab1e430-6082-4465-b057-3408d36de144'
    assert stored_actuals[0]['feature1'] == 1
    assert stored_actuals[0]['income'] == False

    assert stored_actuals[1][
        'prediction_id'] == 'b1bf9ebf-0277-4771-9bc5-236690a21194'
    assert stored_actuals[1][
        'prediction_group_id'] == '2ab1e430-6082-4465-b057-3408d36de144'
    assert stored_actuals[1]['feature1'] == 1.1
    assert stored_actuals[1]['income'] == False

    # This id belongs to a different prediction group than the first two.
    assert stored_actuals[2][
        'prediction_id'] == 'f61b1bbc-6f7b-4e7e-9a3b-6acb6e1462cd'
    assert stored_actuals[2][
        'prediction_group_id'] == '03016c26-f69a-416f-817f-4c58cd69d675'
    assert stored_actuals[2]['feature1'] == 1.3
    assert stored_actuals[2]['income'] == True