Example #1
def test_ensure_schema_exists_creates_missing_schema(
        monkeypatch: MonkeyPatch, caplog: LogCaptureFixture) -> None:
    new_schema = "new_schema"
    mock_get_schema_names = Mock(return_value=["old_schema"])

    internal_database = Database(DatabaseType.internal)
    internal_database._database_schema = new_schema

    original_execute = Connection.execute

    def mock_execute_create_schema(
            connection: Connection,
            statement: Executable) -> Optional[Connection]:
        if isinstance(statement, CreateSchema):
            assert statement.element == new_schema
            return None

        return cast(Connection, original_execute(connection, statement))

    monkeypatch.setattr(MSDialect_pyodbc, "get_schema_names",
                        mock_get_schema_names)
    monkeypatch.setattr(Connection, "execute", mock_execute_create_schema)

    with caplog.at_level(logging.INFO):
        ensure_schema_exists(internal_database)
        assert f"Creating schema: {new_schema} in internal database" in caplog.messages

    mock_get_schema_names.assert_called_once()
Example #2
def test_store_cleaned_data_in_database() -> None:
    first_run_id = -100
    second_run_id = first_run_id + 1

    query_relevant_cleaned_data = Query(CleanedData).filter(
        CleanedData.c.run_id.in_([first_run_id, second_run_id])
    )  # type: ignore

    delete_test_data(query_relevant_cleaned_data)  # Cleanup data from previously failed/cancelled test run

    cleaned_data = pd.DataFrame(
        [
            {
                "Project_ID": "Test_Project",
                "Contract_ID": "Contract_store_cleaned_data",
                "Wesco_Master_Number": "Test_Master_Number",
                "Date": datetime(2019, 12, 1, 0, 0),
                "Date_YYYYMM": 201912,
                "Item_ID": -1,
                "Unit_Cost": 2.0,
                "Order_Quantity": 10.0,
                "Order_Cost": 20.0,
            },
            {
                "Project_ID": "Test_Project",
                "Contract_ID": "Contract_store_cleaned_data",
                "Wesco_Master_Number": "Test_Master_Number",
                "Date": datetime(2020, 1, 1, 0, 0),
                "Date_YYYYMM": 202001,
                "Item_ID": -1,
                "Unit_Cost": 0.0,
                "Order_Quantity": 0.0,
                "Order_Cost": 0.0,
            },
        ]
    )
    first_expected_cleaned_data = cleaned_data.assign(run_id=first_run_id).to_dict(orient="records")
    second_expected_cleaned_data = cleaned_data.assign(run_id=second_run_id).to_dict(orient="records")

    internal_database = Database(DatabaseType.internal)

    def _assert_expected_cleaned_data(expected_cleaned_data: List[Dict[str, Any]], message: str) -> None:
        with internal_database.transaction_context() as session:
            result = [row._asdict() for row in query_relevant_cleaned_data.with_session(session).all()]
            assert expected_cleaned_data == result, message

    data_output = DataOutput(RUNTIME_CONFIG, internal_database, Database(DatabaseType.dsx_write))
    _assert_expected_cleaned_data([], "Expect empty cleaned data at beginning of test")

    data_output.store_cleaned_data(cleaned_data, first_run_id)
    _assert_expected_cleaned_data(first_expected_cleaned_data, "Expect only first cleaned data run")

    data_output.store_cleaned_data(cleaned_data, second_run_id)  # This should cleanup the first_expected_cleaned_data
    _assert_expected_cleaned_data(second_expected_cleaned_data, "Expect only second cleaned data run")

    data_output.store_cleaned_data(cleaned_data, first_run_id)  # This should not cleanup newer second run data
    _assert_expected_cleaned_data(first_expected_cleaned_data + second_expected_cleaned_data, "Don't delete newer data")

    delete_test_data(query_relevant_cleaned_data)  # Cleanup data from test run
    _assert_expected_cleaned_data([], "Expect empty cleaned data at end of test")
Example #3
def drop_known_tables(database: Database) -> None:
    """Permanently delete tables defined in the platform code and remove and all their data from database.

    These tables are defined as subclasses of :class:`~forecasting_platform.internal_schema.InternalSchemaBase` or
    :class:`~forecasting_platform.dsx_schema.DsxWriteSchemaBase`.
    This will not delete unknown tables, e.g. tables used by other programs or removed from the platform.
    """
    if database.is_disabled():
        logger.error(
            f"Cannot drop tables, because {database} connection is not available"
        )
        return

    defined_table_names = database.get_defined_table_names()
    logger.info(f"Dropping own database tables: {defined_table_names}")

    previous_table_names = database.get_existing_table_names()
    logger.info(f"Previously existing tables: {previous_table_names}")

    database.schema_base_class.metadata.drop_all()

    new_table_names = database.get_existing_table_names()
    logger.info(f"Now existing tables: {new_table_names}")

    assert all(name in defined_table_names
               for name in set(previous_table_names) - set(new_table_names))
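
# Hedged sketch (not from the source): a guard test for the disabled-database branch of
# drop_known_tables above, mirroring the pattern of Examples #15 and #24. The test name
# and the choice of DatabaseType.internal are illustrative assumptions.
def test_no_drop_tables_when_disabled(monkeypatch: MonkeyPatch,
                                      caplog: LogCaptureFixture) -> None:
    monkeypatch.setattr(master_config, "db_connection_attempts", 0)

    database = Database(DatabaseType.internal)

    mock_drop_all = Mock()
    monkeypatch.setattr(database.schema_base_class.metadata, "drop_all",
                        mock_drop_all)

    with caplog.at_level(logging.ERROR):
        drop_known_tables(database)
        assert f"Cannot drop tables, because {database} connection is not available" in caplog.messages

    mock_drop_all.assert_not_called()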
def prepare_database() -> Iterator[None]:
    internal_database = Database(DatabaseType.internal)
    with internal_database.transaction_context() as session:
        model_run_with_one_result_row = ForecastModelRun(
            model_name="model_run_with_one_forecast_data_row",
            start=FIXED_TIMESTAMP,
            end=FIXED_TIMESTAMP,
            status=ForecastModelRunStatus.COMPLETED,
        )
        model_run_with_zero_result_rows = ForecastModelRun(
            model_name="model_run_with_zero_forecast_data_rows",
            start=FIXED_TIMESTAMP,
            end=FIXED_TIMESTAMP,
            status=ForecastModelRunStatus.COMPLETED,
        )
        forecast_run = ForecastRun(
            start=FIXED_TIMESTAMP,
            end=FIXED_TIMESTAMP,
            run_type=EngineRunType.production,
            includes_cleaning=True,
            status=ForecastRunStatus.COMPLETED,
            forecast_periods=1,
            prediction_start_month=FIXED_TIMESTAMP,
            model_runs=[
                model_run_with_one_result_row, model_run_with_zero_result_rows
            ],  # type: ignore
        )

        session.add(forecast_run)
        session.flush()
        test_model_run_id = model_run_with_one_result_row.id

        forecast_data = {
            "model_run_id": test_model_run_id,
            "Contract_ID": "test_contract_for_compare_structure_database",
            "Item_ID": -1,
            "Prediction_Start_Month": 20200101,
            "Predicted_Month": 20200101,
            "Prediction_Months_Delta": 0,
            "Prediction_Raw": 0,
            "Prediction_Post": 0,
            "Actual": None,
            "Accuracy": None,
        }

        session.execute(ForecastData.insert().values(forecast_data))

    yield

    # cleanup
    with internal_database.transaction_context() as session:
        forecast_run = session.merge(forecast_run)
        # following delete will cascade and delete also forecast_model_runs due to FK relationship
        session.delete(forecast_run)  # type: ignore
        assert (delete_test_data(
            Query(ForecastData).filter(ForecastData.c.model_run_id ==
                                       test_model_run_id)  # type: ignore
        ) == 1), "Cleanup failed, check database for uncleaned data"
Example #5
def update_forecast_data_with_cleaned_data_sales(
        internal_database: Database, cleaned_data_run_id: Optional[int],
        cleaned_data_newest_month: Optional[datetime]) -> None:
    """Update actual values for all previous forecasts with newest cleaned data up to newest date in cleaned data.

    NULL actual values are set to zero if their date is before or equal to the newest date in the cleaned data.
    Any dates after the newest date are left unchanged.

    Args:
        internal_database: Service to access internal database.
        cleaned_data_run_id: ID of latest successful platform run, which included cleaning.
        cleaned_data_newest_month: Newest month in cleaned data table associated with the ``cleaned_data_run_id``.
    """
    if internal_database.is_disabled():
        logger.warning(
            "Skipping update of previous forecasts due to disabled database")
        return

    assert cleaned_data_run_id is not None, "Invalid program state: Expected cleaned_data_run_id to exist"
    with internal_database.transaction_context() as session:
        updated_existing = session.execute(
            ForecastData.update()
            .where(CleanedData.c.run_id == cleaned_data_run_id)
            .where(ForecastData.c.Contract_ID == CleanedData.c.Contract_ID)
            .where(ForecastData.c.Item_ID == CleanedData.c.Item_ID)
            .where(ForecastData.c.Predicted_Month == CleanedData.c.Date_YYYYMM)
            .values({
                "Actual": CleanedData.c.Order_Quantity,
                "Accuracy": compute_accuracy_as_sql(
                    CleanedData.c.Order_Quantity,
                    ForecastData.c.Prediction_Post),
            }))

    logger.info(
        f"Updated {updated_existing.rowcount} rows of forecast_data with old actual values to"
        f" newest actual values from cleaned_data")

    assert cleaned_data_newest_month, "Invalid program state: Expected cleaned_data_newest_month to exist"
    newest_month = int(
        cleaned_data_newest_month.strftime(PREDICTION_MONTH_FORMAT))
    with internal_database.transaction_context() as session:
        updated_nulls = session.execute(
            ForecastData.update()
            .where(ForecastData.c.Predicted_Month <= newest_month)
            .where(ForecastData.c.Actual == None)  # noqa: E711
            .values({
                "Actual": 0,
                "Accuracy": compute_accuracy_as_sql(
                    0, ForecastData.c.Prediction_Post),
            }))

    logger.info(
        f"Updated {updated_nulls.rowcount} rows of forecast_data without actual values to"
        f" newest actual values from cleaned_data")
Example #6
def test_store_exogenous_features() -> None:
    first_run_id = -90
    second_run_id = first_run_id + 1

    query_relevant_data = Query(ExogenousFeature).filter(
        ExogenousFeature.c.run_id.in_([first_run_id, second_run_id])
    )  # type: ignore

    delete_test_data(query_relevant_data)  # Cleanup data from previously failed/cancelled test run

    exogenous_features = pd.DataFrame(
        [
            {
                "Periodic_Data_Stream": "Test_Data",
                "Airframe": "Test_Airframe",
                "Contract_ID": "Contract_exogenous_feature",
                "Project_ID": "Test_Project",
                "Date": datetime(2019, 12, 1, 0, 0),
                "Value": 20.0,
            },
            {
                "Periodic_Data_Stream": "Test_Data",
                "Airframe": "Test_Airframe",
                "Contract_ID": "Contract_exogenous_feature",
                "Project_ID": "Test_Project",
                "Date": datetime(2020, 1, 1, 0, 0),
                "Value": 0.1,
            },
        ]
    )
    first_expected_data = exogenous_features.assign(run_id=first_run_id).to_dict(orient="records")
    second_expected_data = exogenous_features.assign(run_id=second_run_id).to_dict(orient="records")

    internal_database = Database(DatabaseType.internal)

    def _assert_expected_exogenous_features(expected_exogenous_feature: List[Dict[str, Any]], message: str) -> None:
        with internal_database.transaction_context() as session:
            result = [row._asdict() for row in query_relevant_data.with_session(session).all()]
            assert expected_exogenous_feature == sorted(result, key=itemgetter("run_id")), message

    data_output = DataOutput(RUNTIME_CONFIG, internal_database, Database(DatabaseType.dsx_write))
    _assert_expected_exogenous_features([], "Expect empty exogenous feature data at beginning of test")

    data_output.store_exogenous_features(exogenous_features, first_run_id)
    _assert_expected_exogenous_features(first_expected_data, "Expect only first exogenous feature data run")

    data_output.store_exogenous_features(exogenous_features, second_run_id)
    # This should cleanup the previously inserted exogenous feature data
    _assert_expected_exogenous_features(second_expected_data, "Expect only second exogenous feature data run")

    data_output.store_exogenous_features(exogenous_features, first_run_id)
    # This should not cleanup newer second run data
    _assert_expected_exogenous_features(first_expected_data + second_expected_data, "Don't delete newer data")

    delete_test_data(query_relevant_data)  # Cleanup data from test run
    _assert_expected_exogenous_features([], "Expect empty exogenous feature data at end of test")
def with_cleaned_and_forecast_data_in_database() -> Iterator[Orchestrator]:
    cleanup_cleaned_data_query = Query(CleanedData).filter(
        CleanedData.c.Contract_ID == TEST_CONTRACT)  # type: ignore
    cleanup_forecast_data_query = Query(ForecastData).filter(
        ForecastData.c.Contract_ID == TEST_CONTRACT)  # type: ignore
    # Cleanup in case of previously failed test
    delete_test_data(cleanup_cleaned_data_query)
    delete_test_data(cleanup_forecast_data_query)

    with Database(DatabaseType.internal).transaction_context() as session:
        cleaned_data_count = cleanup_cleaned_data_query.with_session(
            session).count()
        forecast_data_count = cleanup_forecast_data_query.with_session(
            session).count()
    assert cleaned_data_count + forecast_data_count == 0, "Found old test data in database when setting up the test"

    runtime_config = RuntimeConfig(
        engine_run_type=EngineRunType.development,
        forecast_periods=1,
        output_location=".",
        prediction_month=pd.Timestamp(year=2020, month=2, day=1),
    )

    orchestrator = Orchestrator(
        runtime_config,
        Mock(spec=DataLoader),
        Mock(spec=DataOutput),
        Database(DatabaseType.internal),
        Mock(spec=Queue),
        Mock(),
        Mock(),
    )

    test_cleaned_data = _setup_cleaned_data(orchestrator)
    test_forecast_data = _setup_forecast_data(orchestrator)
    test_run_id = cast(int, orchestrator._forecast_run_id)

    with Database(DatabaseType.internal).transaction_context() as session:
        session.execute(CleanedData.insert().values(test_cleaned_data))
        session.execute(ForecastData.insert().values(test_forecast_data))

    yield orchestrator

    assert delete_test_data(cleanup_cleaned_data_query) == len(
        test_cleaned_data)
    assert delete_test_data(cleanup_forecast_data_query) == len(
        test_forecast_data)
    assert delete_test_data(
        Query(ForecastRun).filter(
            ForecastRun.id == test_run_id)) == 1  # type: ignore
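
# Hedged sketch (not from the source): a consuming test for the generator above, assuming
# it is registered as a pytest fixture under the same name. The test name and the exact
# assertions are illustrative.
def test_setup_data_is_visible(
        with_cleaned_and_forecast_data_in_database: Orchestrator) -> None:
    orchestrator = with_cleaned_and_forecast_data_in_database
    assert orchestrator._forecast_run_id is not None

    with Database(DatabaseType.internal).transaction_context() as session:
        cleaned_count = session.query(CleanedData).filter(  # type: ignore
            CleanedData.c.Contract_ID == TEST_CONTRACT).count()
    assert cleaned_count > 0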
Example #8
def test_ensure_schema_exists(monkeypatch: MonkeyPatch,
                              caplog: LogCaptureFixture) -> None:
    existing_schema = "existing_schema"
    mock_get_schema_names = Mock(return_value=[existing_schema])
    monkeypatch.setattr(MSDialect_pyodbc, "get_schema_names",
                        mock_get_schema_names)

    internal_database = Database(DatabaseType.internal)
    internal_database._database_schema = existing_schema

    with caplog.at_level(logging.INFO):
        ensure_schema_exists(internal_database)
        assert f"Found schema: {existing_schema} in internal database" in caplog.messages

    mock_get_schema_names.assert_called_once()
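
# Hedged sketch (not from the source): the disabled-database branch of ensure_schema_exists
# (Example #14) has no test in this collection; this mirrors Example #24. The test name is
# an illustrative assumption.
def test_no_ensure_schema_when_disabled(monkeypatch: MonkeyPatch,
                                        caplog: LogCaptureFixture) -> None:
    monkeypatch.setattr(master_config, "db_connection_attempts", 0)

    internal_database = Database(DatabaseType.internal)

    with caplog.at_level(logging.ERROR):
        ensure_schema_exists(internal_database)
        assert f"Cannot setup schema, because {internal_database} connection is not available" in caplog.messages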
Example #9
def get_last_successful_cleaning_run_id(internal_database: Database) -> int:
    """Returns last successful and completed cleaning run ID.

    Args:
        internal_database: Service to access internal database.
    """
    with internal_database.transaction_context() as session:
        cleaning_run_id = (
            session.query(func.max(ForecastRun.id))  # type: ignore
            .filter(ForecastRun.status == ForecastRunStatus.COMPLETED).filter(
                ForecastRun.includes_cleaning == True).scalar())
        logger.debug(
            f"Found highest completed cleaning run ID: {cleaning_run_id}")

        cleaned_data_ids = session.query(distinct(
            CleanedData.c.run_id)).all()  # type: ignore
        logger.debug(
            f"Found cleaned data IDs in internal database: {cleaned_data_ids}")

        if (not cleaning_run_id or len(cleaned_data_ids) != 1
                or cleaned_data_ids[0][0] != cleaning_run_id):
            raise DataException(
                "Cannot determine valid cleaning data, "
                "please re-run cleaning steps, "
                'e.g. by providing the "--force-reload" parameter')
        return int(cleaning_run_id)
def assert_forecast_data_for_model_run(internal_database: Database,
                                       model_run_id: int,
                                       run: ForecastRun) -> None:
    """Assert count of rows in forecast data table of internal database against expected structure."""
    with internal_database.transaction_context() as session:
        model_name = (
            session.query(ForecastModelRun)  # type: ignore
            .filter(ForecastModelRun.id == model_run_id).first().model_name)

    run_parameters = ExpectedForecastStructureParameters(
        account_name=model_name,
        forecast_periods=run.forecast_periods,
        prediction_month=run.prediction_start_month,
    )
    expected_forecast_structure = get_expected_forecast_structure(
        run_parameters)
    expected_number_of_rows = expected_forecast_structure.shape[0]
    forecast_data_count = get_forecast_data_count_for_model_run(
        internal_database, model_run_id)

    assert forecast_data_count == expected_number_of_rows, (
        f"Forecast data count ({forecast_data_count}) for {model_name} (model_run={model_run_id}) "
        f"does not match the expectation ({expected_number_of_rows}) defined by {run_parameters})"
    )
    logger.info(
        f"Asserted forecast data count ({forecast_data_count}) for {model_name} (model_run={model_run_id})"
    )
    def test_forecast_data_in_dsx_write_database(self, production_forecast: OrchestratorResult) -> None:
        with Database(DatabaseType.dsx_write).transaction_context() as session:
            dsx_output_data_count = session.query(DsxOutput).count()  # type: ignore
            # Currently there is no good identifier to get the results of this specific test run
            assert (
                dsx_output_data_count >= self.EXPECTED_FORECAST_DATA_COUNT
            )  # Further checks are done in the end-to-end test
Example #12
def test_defined_tables_internal_database() -> None:
    assert Database(DatabaseType.internal).get_defined_table_names() == [
        "ml_internal.cleaned_data",
        "ml_internal.exogenous_feature",
        "ml_internal.forecast_data",
        "ml_internal.forecast_model_run",
        "ml_internal.forecast_run",
    ]
    def test_forecast_data_in_internal_database(self, production_forecast: OrchestratorResult) -> None:
        model_run_id = production_forecast.forecast_result.model_run_id
        expected = [
            (
                model_run_id,
                ContractID("Contract_378"),
                64987,
                202001,
                202001,
                0,
                pytest.approx(29.589586),
                pytest.approx(29.589586),
                0,
                0.0,
            ),
            (
                model_run_id,
                ContractID("Contract_378"),
                64987,
                202001,
                202002,
                1,
                pytest.approx(60.164522),
                pytest.approx(60.164522),
                138,
                pytest.approx(0.4359747),
            ),
            (
                model_run_id,
                ContractID("Contract_378"),
                64987,
                202001,
                202003,
                2,
                pytest.approx(25.365093),
                pytest.approx(25.365093),
                0,
                0.0,
            ),
        ]

        with Database(DatabaseType.internal).transaction_context() as session:
            forecast_data = (
                session.query(ForecastData)  # type: ignore
                .filter(ForecastData.c.model_run_id == model_run_id)
                .filter(ForecastData.c.Contract_ID == ContractID("Contract_378"))
                .filter(ForecastData.c.Item_ID == 64987)
                .all()
            )
            assert forecast_data == expected

            forecast_data_count = (
                session.query(ForecastData)  # type: ignore
                .filter(ForecastData.c.model_run_id == model_run_id)
                .count()
            )
            assert forecast_data_count == self.EXPECTED_FORECAST_DATA_COUNT
Example #14
def ensure_schema_exists(database: Database) -> None:
    """Ensure database schema configured to be used with database exists and create it if missing."""
    if database.is_disabled():
        logger.error(
            f"Cannot setup schema, because {database} connection is not available"
        )
        return

    with database.transaction_context() as session:
        connection = session.connection()
        schema_name = database._database_schema
        existing_schemas = connection.dialect.get_schema_names(connection)

        if schema_name in existing_schemas:
            logger.info(f"Found schema: {schema_name} in {database}")
            return

        logger.info(f"Creating schema: {schema_name} in {database}")
        connection.execute(schema.CreateSchema(schema_name))
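
# Hedged sketch (not from the source): a plausible bootstrap sequence combining
# ensure_schema_exists above with ensure_tables_exist (Example #25). The wrapper name
# is an illustrative assumption.
def prepare_database_structure(database: Database) -> None:
    ensure_schema_exists(database)  # create the configured schema if it is missing
    ensure_tables_exist(database)  # create missing tables defined in the platform code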
Example #15
    def test_drop_tables(self, database_type: DatabaseType,
                         monkeypatch: MonkeyPatch) -> None:
        mock_drop_all = Mock()
        database = Database(database_type)
        monkeypatch.setattr(database.schema_base_class.metadata, "drop_all",
                            mock_drop_all)

        drop_known_tables(database)

        mock_drop_all.assert_called_once()
Example #16
    def test_ensure_tables(self, database_type: DatabaseType,
                           monkeypatch: MonkeyPatch) -> None:
        database = Database(database_type)

        mock_create_all = Mock()
        monkeypatch.setattr(database.schema_base_class.metadata, "create_all",
                            mock_create_all)

        ensure_tables_exist(database)

        mock_create_all.assert_called_once()
def assert_dsx_output_count(dsx_write_database: Database,
                            expected_count: int) -> None:
    """Assert count of of values saved in dsx write database."""
    with dsx_write_database.transaction_context() as session:
        dsx_data_count = session.query(DsxOutput).count()  # type: ignore

    assert dsx_data_count == expected_count, (
        f"Total count of  values ({dsx_data_count}) in {DsxOutput.name} table of {dsx_write_database} "
        f"does not match the expectation ({expected_count})")
    logger.info(
        f"Asserted total count of values ({expected_count}) in {DsxOutput.name} "
        f"of {dsx_write_database}")
def get_forecast_data_count_for_model_run(internal_database: Database,
                                          model_run_id: int) -> int:
    """Return count of forecast_data entries associated with model run."""
    with internal_database.transaction_context() as session:
        forecast_data_count = (
            session.query(ForecastData)  # type: ignore
            .filter(ForecastData.c.model_run_id == model_run_id).count())

    logger.debug(
        f"Found {forecast_data_count} forecast data rows associated with forecast model run: {model_run_id}"
    )
    return cast(int, forecast_data_count)
Example #19
def test_not_store_backward_in_database(
    data_loader: DataLoader, caplog: LogCaptureFixture, monkeypatch: MonkeyPatch, tmp_path: Path,
) -> None:
    internal_database = Database(DatabaseType.internal)
    dsx_write_database = Database(DatabaseType.dsx_write)
    data_output = DataOutput(RUNTIME_CONFIG, internal_database, dsx_write_database)

    model_config_account_1 = ModelConfigAccount1(runtime_config=RUNTIME_CONFIG, data_loader=data_loader)
    account_data = data_loader.load_account_data(model_config_account_1, -1)
    model_run = ForecastModelRun()
    model_config_account_1.forecast_path = tmp_path
    with caplog.at_level(logging.DEBUG):
        returned_model_run = data_output.store_forecast(
            model_config=model_config_account_1,
            model_run=model_run,
            account_data=account_data,
            forecast_raw=TEST_ACCOUNT_1_RAW_DATA,
            forecast_post=TEST_ACCOUNT_1_POST_DATA,
            actuals_newest_month=datetime(2019, 10, 1, 0, 0),
        )
    assert RUNTIME_CONFIG.engine_run_type == EngineRunType.backward
    assert returned_model_run is model_run
    assert "Skip storing forecast in internal database for backward run." in caplog.messages
def get_last_successful_production_run(
        internal_database: Database) -> ForecastRun:
    """Return last successful production run as detached object."""
    with internal_database.transaction_context() as session:
        session.expire_on_commit = False
        forecast_run = (
            session.query(ForecastRun)  # type: ignore
            .filter(ForecastRun.status == ForecastRunStatus.COMPLETED).filter(
                ForecastRun.run_type == EngineRunType.production).order_by(
                    desc(ForecastRun.id)).limit(1).first())

    logger.debug(
        f"Found last completed production forecast run: {forecast_run}")
    return cast(ForecastRun, forecast_run)
Example #21
def get_newest_cleaned_data_month(internal_database: Database,
                                  cleaned_data_run_id: int) -> datetime:
    """Returns most recent month from cleaned data table of internal database.

    Args:
        internal_database: Service to access internal database.
        cleaned_data_run_id: ID of latest successful platform run, which included cleaning.
    """
    with internal_database.transaction_context() as session:
        return cast(
            datetime,
            session.query(func.max(CleanedData.c.Date))  # type: ignore
            .filter(CleanedData.c.run_id == cleaned_data_run_id).scalar(),
        )
def assert_dsx_output_total_sum(dsx_write_database: Database,
                                expected_total_sum: int) -> None:
    """Assert count of values in dsx write database."""
    with dsx_write_database.transaction_context() as session:
        dsx_total_sum = sum([
            int(result.Value)
            for result in session.query(DsxOutput.c.Value).all()
        ])  # type: ignore

    assert dsx_total_sum == expected_total_sum, (
        f"Total sum of values ({dsx_total_sum}) in {DsxOutput.name} table of {dsx_write_database} "
        f"does not match the expectation ({expected_total_sum})")
    logger.info(
        f"Asserted total sum of values ({expected_total_sum}) in {DsxOutput.name} "
        f"of {dsx_write_database}")
def get_model_run_ids_for_forecast_run(internal_database: Database,
                                       forecast_run_id: int) -> List[int]:
    """Return successfully completed model run IDs associated with the forecast run."""
    with internal_database.transaction_context() as session:
        model_run_ids = [
            result.id for result in (
                session.query(ForecastModelRun)  # type: ignore
                .filter(ForecastModelRun.run_id == forecast_run_id).filter(
                    ForecastModelRun.status ==
                    ForecastModelRunStatus.COMPLETED).all())
        ]
    logger.debug(
        f"Found {len(model_run_ids)} completed model run(s) (id={model_run_ids}) "
        f"associated with forecast run ID: {forecast_run_id}")
    return model_run_ids
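
# Hedged sketch (not from the source): how the assertion helpers in this document could be
# composed to verify the last production run, chaining get_last_successful_production_run
# (Example #19), get_model_run_ids_for_forecast_run above, and
# assert_forecast_data_for_model_run (Example #9). The wrapper name is illustrative.
def assert_last_production_run_structure(internal_database: Database) -> None:
    forecast_run = get_last_successful_production_run(internal_database)
    model_run_ids = get_model_run_ids_for_forecast_run(internal_database,
                                                       forecast_run.id)
    for model_run_id in model_run_ids:
        assert_forecast_data_for_model_run(internal_database, model_run_id,
                                           forecast_run)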
Example #24
    def test_no_ensure_tables_when_disabled(self, database_type: DatabaseType,
                                            monkeypatch: MonkeyPatch,
                                            caplog: LogCaptureFixture) -> None:
        monkeypatch.setattr(master_config, "db_connection_attempts", 0)

        database = Database(database_type)

        mock_create_all = Mock()
        monkeypatch.setattr(database.schema_base_class.metadata, "create_all",
                            mock_create_all)

        caplog.clear()
        with caplog.at_level(logging.INFO):
            ensure_tables_exist(database)
            assert f"Cannot setup tables, because {database} connection is not available" in caplog.messages

        mock_create_all.assert_not_called()
Example #25
def ensure_tables_exist(database: Database) -> None:
    """Ensure tables defined in the platform code exist in the database and create them if they are missing.

    These tables are defined as subclasses of :class:`~forecasting_platform.internal_schema.InternalSchemaBase` or
    :class:`~forecasting_platform.dsx_schema.DsxWriteSchemaBase`.
    This will not update existing tables, e.g. when new columns are added.
    """
    if database.is_disabled():
        logger.error(
            f"Cannot setup tables, because {database} connection is not available"
        )
        return

    logger.info(
        f"Previously existing tables: {database.get_existing_table_names()}")

    database.schema_base_class.metadata.create_all()

    logger.info(f"Now existing tables: {database.get_existing_table_names()}")
def insert_cleaned_data_for_database_test(
        model_config_class: Type[BaseModelConfig],
        test_run_id: int,
        disable_internal_database: bool = False) -> Iterator[None]:
    """Insert cleaned data for integration tests by loading a matching CSV file from anonymized_data_dsx.zip.

    Parameters
    ----------
    model_config_class
        Model config class that determines which file will be loaded from anonymized_data_dsx.zip.
    test_run_id
        Integer ID which should uniquely identify the cleaned data for a specific test.
    disable_internal_database
        If True, this setup step is skipped.
    """
    if disable_internal_database:
        yield  # We can skip this function for tests that do not use the real database
        return

    # Cleanup potentially left-over data from previously failed run
    delete_test_data(
        Query(CleanedData).filter(
            CleanedData.c.run_id == test_run_id))  # type: ignore

    # Insert fresh cleaned data for this test
    for contract in model_config_class.CONTRACTS:  # type: ignore
        file_path = (Path(master_config.default_data_loader_location) /
                     master_config.account_processed_data_path /
                     f"DSX_{contract}_Data.csv.gz")
        cleaned_data = pd.read_csv(file_path)
        cleaned_data = cleaned_data[SALES_COLUMNS].assign(run_id=test_run_id)
        Database(DatabaseType.internal).insert_data_frame(
            cleaned_data, CLEANED_DATA_TABLE)

    yield

    # Remove test data from database
    delete_test_data(
        Query(CleanedData).filter(
            CleanedData.c.run_id == test_run_id))  # type: ignore
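
# Hedged sketch (not from the source): the generator above is assumed to be wrapped as a
# pytest fixture so that setup runs before a test and cleanup afterwards. ModelConfigAccount1
# and the run ID are illustrative values reused from Example #19.
@pytest.fixture
def cleaned_data_for_account_1() -> Iterator[None]:
    yield from insert_cleaned_data_for_database_test(ModelConfigAccount1,
                                                     test_run_id=-1)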
def delete_test_data(db_query: Query, retry: bool = True) -> int:
    """Delete database objects based on user query. Intended to be used only for integration tests cleanup.

    Args:
        db_query: sqlalchemy Query object that defines the rows to delete
        retry: Attempt retry in case of deadlock errors due to parallel runs of tests

    Returns:
        Number of deleted rows.

    """
    try:
        with Database(DatabaseType.internal).transaction_context() as session:
            return int(
                db_query.with_session(session).delete(
                    synchronize_session=False))  # type: ignore
    except DBAPIError as e:  # pragma: no cover
        if retry:  # Retry in case of deadlock error from the database when running tests in parallel
            logger.warning(f"Retrying error: {e}")
            return delete_test_data(db_query, retry=False)
        else:
            raise
Example #28
def test_defined_tables_dsx_write_database() -> None:
    assert Database(DatabaseType.dsx_write).get_defined_table_names() == [
        "ml_dsx_write.STG_Import_Periodic_ML",
    ]
Example #29
def data_output() -> DataOutput:
    internal_database = Database(DatabaseType.internal)
    dsx_write_database = Database(DatabaseType.dsx_write)
    internal_database.is_disabled = lambda: True  # type: ignore
    return DataOutput(RUNTIME_CONFIG, internal_database, dsx_write_database)
Example #30
def data_loader() -> DataLoader:
    internal_database = Database(DatabaseType.internal)
    dsx_read_database = Database(DatabaseType.dsx_read)
    internal_database.is_disabled = lambda: True  # type: ignore
    return DataLoader(internal_database, dsx_read_database)