def prepare_database() -> Iterator[None]:
    internal_database = Database(DatabaseType.internal)
    with internal_database.transaction_context() as session:
        model_run_with_one_result_row = ForecastModelRun(
            model_name="model_run_with_one_forecast_data_row",
            start=FIXED_TIMESTAMP,
            end=FIXED_TIMESTAMP,
            status=ForecastModelRunStatus.COMPLETED,
        )
        model_run_with_zero_result_rows = ForecastModelRun(
            model_name="model_run_with_zero_forecast_data_rows",
            start=FIXED_TIMESTAMP,
            end=FIXED_TIMESTAMP,
            status=ForecastModelRunStatus.COMPLETED,
        )
        forecast_run = ForecastRun(
            start=FIXED_TIMESTAMP,
            end=FIXED_TIMESTAMP,
            run_type=EngineRunType.production,
            includes_cleaning=True,
            status=ForecastRunStatus.COMPLETED,
            forecast_periods=1,
            prediction_start_month=FIXED_TIMESTAMP,
            model_runs=[
                model_run_with_one_result_row, model_run_with_zero_result_rows
            ],  # type: ignore
        )

        session.add(forecast_run)
        session.flush()
        test_model_run_id = model_run_with_one_result_row.id

        forecast_data = {
            "model_run_id": test_model_run_id,
            "Contract_ID": "test_contract_for_compare_structure_database",
            "Item_ID": -1,
            "Prediction_Start_Month": 20200101,
            "Predicted_Month": 20200101,
            "Prediction_Months_Delta": 0,
            "Prediction_Raw": 0,
            "Prediction_Post": 0,
            "Actual": None,
            "Accuracy": None,
        }

        session.execute(ForecastData.insert().values(forecast_data))

    yield

    # cleanup
    with internal_database.transaction_context() as session:
        forecast_run = session.merge(forecast_run)
        # The following delete cascades to forecast_model_runs via the FK relationship.
        session.delete(forecast_run)  # type: ignore
        deleted_rows = delete_test_data(
            Query(ForecastData).filter(
                ForecastData.c.model_run_id == test_model_run_id))  # type: ignore
        assert deleted_rows == 1, "Cleanup failed, check database for uncleaned data"
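A minimal sketch of exposing this setup/teardown generator as a pytest fixture; the decorator and test name below are assumptions, not part of the original:

import pytest

@pytest.fixture
def prepared_database() -> Iterator[None]:
    # Everything before the yield in prepare_database runs as setup,
    # everything after it as cleanup.
    yield from prepare_database()

def test_forecast_data_row_exists(prepared_database: None) -> None:
    ...  # exercise code that relies on the seeded forecast_data row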
Example #2
def update_forecast_data_with_cleaned_data_sales(
        internal_database: Database, cleaned_data_run_id: Optional[int],
        cleaned_data_newest_month: Optional[datetime]) -> None:
    """Update actual values for all previous forecasts with newest cleaned data up to newest date in cleaned data.

    Actual values with NULL are set to zero if date is before or equal to the newest date in cleaning data.
    Any dates after newest date are left unchanged.

    Args:
        internal_database: Service to access internal database.
        cleaned_data_run_id: ID of latest successful platform run, which included cleaning.
        cleaned_data_newest_month: Newest month in cleaned data table associated with the ``cleaned_data_run_id``.
    """
    if internal_database.is_disabled():
        logger.warning(
            "Skipping update of previous forecasts due to disabled database")
        return

    assert cleaned_data_run_id is not None, "Invalid program state: Expected cleaned_data_run_id to exist"
    with internal_database.transaction_context() as session:
        updated_existing = session.execute(
            ForecastData.update()
            .where(CleanedData.c.run_id == cleaned_data_run_id)
            .where(ForecastData.c.Contract_ID == CleanedData.c.Contract_ID)
            .where(ForecastData.c.Item_ID == CleanedData.c.Item_ID)
            .where(ForecastData.c.Predicted_Month == CleanedData.c.Date_YYYYMM)
            .values({
                "Actual": CleanedData.c.Order_Quantity,
                "Accuracy": compute_accuracy_as_sql(
                    CleanedData.c.Order_Quantity,
                    ForecastData.c.Prediction_Post),
            }))

    logger.info(
        f"Updated {updated_existing.rowcount} rows of forecast_data, replacing old"
        " actual values with the newest actual values from cleaned_data")

    assert cleaned_data_newest_month, "Invalid program state: Expected cleaned_data_newest_month to exist"
    newest_month = int(
        cleaned_data_newest_month.strftime(PREDICTION_MONTH_FORMAT))
    with internal_database.transaction_context() as session:
        updated_nulls = session.execute(
            ForecastData.update()
            .where(ForecastData.c.Predicted_Month <= newest_month)
            .where(ForecastData.c.Actual == None)  # noqa: E711
            .values({
                "Actual": 0,
                "Accuracy": compute_accuracy_as_sql(
                    0, ForecastData.c.Prediction_Post),
            }))

    logger.info(
        f"Updated {updated_nulls.rowcount} rows of forecast_data without actual values,"
        " setting missing actuals to zero")
Example #3
def test_ensure_schema_exists_creates_missing_schema(
        monkeypatch: MonkeyPatch, caplog: LogCaptureFixture) -> None:
    new_schema = "new_schema"
    mock_get_schema_names = Mock(return_value=["old_schema"])

    internal_database = Database(DatabaseType.internal)
    internal_database._database_schema = new_schema

    original_execute = Connection.execute

    def mock_execute_create_schema(
            connection: Connection,
            statement: Executable) -> Optional[Connection]:
        if isinstance(statement, CreateSchema):
            assert statement.element == new_schema
            return None

        # Delegate everything that is not a CreateSchema to the real execute,
        # captured above before it gets monkeypatched below. (The original
        # called a Connection instance, which is not callable.)
        return cast(Connection, original_execute(connection, statement))

    monkeypatch.setattr(MSDialect_pyodbc, "get_schema_names",
                        mock_get_schema_names)
    monkeypatch.setattr(Connection, "execute", mock_execute_create_schema)

    with caplog.at_level(logging.INFO):
        ensure_schema_exists(internal_database)
        assert f"Creating schema: {new_schema} in internal database" in caplog.messages

    mock_get_schema_names.assert_called_once()
Example #4
def get_last_successful_cleaning_run_id(internal_database: Database) -> int:
    """Returns last successful and completed cleaning run ID.

    Args:
        internal_database: Service to access internal database.
    """
    with internal_database.transaction_context() as session:
        cleaning_run_id = (
            session.query(func.max(ForecastRun.id))  # type: ignore
            .filter(ForecastRun.status == ForecastRunStatus.COMPLETED)
            .filter(ForecastRun.includes_cleaning == True)  # noqa: E712
            .scalar())
        logger.debug(
            f"Found highest completed cleaning run ID: {cleaning_run_id}")

        cleaned_data_ids = session.query(distinct(
            CleanedData.c.run_id)).all()  # type: ignore
        logger.debug(
            f"Found cleaned data IDs in internal database: {cleaned_data_ids}")

        if (not cleaning_run_id or len(cleaned_data_ids) != 1
                or cleaned_data_ids[0][0] != cleaning_run_id):
            raise DataException(
                "Cannot determine valid cleaning data, "
                "please re-run cleaning steps, "
                'e.g. by providing the "--force-reload" parameter')
        return int(cleaning_run_id)
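Hypothetical wiring of the cleaning-related helpers on this page (the call order is assumed, not taken from the original source):

cleaned_data_run_id = get_last_successful_cleaning_run_id(internal_database)
cleaned_data_newest_month = get_newest_cleaned_data_month(
    internal_database, cleaned_data_run_id)
update_forecast_data_with_cleaned_data_sales(
    internal_database, cleaned_data_run_id, cleaned_data_newest_month)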
def assert_forecast_data_for_model_run(internal_database: Database,
                                       model_run_id: int,
                                       run: ForecastRun) -> None:
    """Assert count of rows in forecast data table of internal database against expected structure."""
    with internal_database.transaction_context() as session:
        model_name = (
            session.query(ForecastModelRun)  # type: ignore
            .filter(ForecastModelRun.id == model_run_id).first().model_name)

    run_parameters = ExpectedForecastStructureParameters(
        account_name=model_name,
        forecast_periods=run.forecast_periods,
        prediction_month=run.prediction_start_month,
    )
    expected_forecast_structure = get_expected_forecast_structure(
        run_parameters)
    expected_number_of_rows = expected_forecast_structure.shape[0]
    forecast_data_count = get_forecast_data_count_for_model_run(
        internal_database, model_run_id)

    assert forecast_data_count == expected_number_of_rows, (
        f"Forecast data count ({forecast_data_count}) for {model_name} (model_run={model_run_id}) "
        f"does not match the expectation ({expected_number_of_rows}) defined by {run_parameters})"
    )
    logger.info(
        f"Asserted forecast data count ({forecast_data_count}) for {model_name} (model_run={model_run_id})"
    )
def assert_dsx_output_count(dsx_write_database: Database,
                            expected_count: int) -> None:
    """Assert count of of values saved in dsx write database."""
    with dsx_write_database.transaction_context() as session:
        dsx_data_count = session.query(DsxOutput).count()  # type: ignore

    assert dsx_data_count == expected_count, (
        f"Total count of  values ({dsx_data_count}) in {DsxOutput.name} table of {dsx_write_database} "
        f"does not match the expectation ({expected_count})")
    logger.info(
        f"Asserted total count of values ({expected_count}) in {DsxOutput.name} "
        f"of {dsx_write_database}")
def get_forecast_data_count_for_model_run(internal_database: Database,
                                          model_run_id: int) -> int:
    """Return count of forecast_data entries associated with model run."""
    with internal_database.transaction_context() as session:
        forecast_data_count = (
            session.query(ForecastData)  # type: ignore
            .filter(ForecastData.c.model_run_id == model_run_id).count())

    logger.debug(
        f"Found {forecast_data_count} forecast data rows associated with forecast model run: {model_run_id}"
    )
    return cast(int, forecast_data_count)
Example #8
def get_newest_cleaned_data_month(internal_database: Database,
                                  cleaned_data_run_id: int) -> datetime:
    """Returns most recent month from cleaned data table of internal database.

    Args:
        internal_database: Service to access internal database.
        cleaned_data_run_id: ID of latest successful platform run, which included cleaning.
    """
    with internal_database.transaction_context() as session:
        return cast(
            datetime,
            session.query(func.max(CleanedData.c.Date))  # type: ignore
            .filter(CleanedData.c.run_id == cleaned_data_run_id).scalar(),
        )
def get_last_successful_production_run(
        internal_database: Database) -> ForecastRun:
    """Return last successful production run as detached object."""
    with internal_database.transaction_context() as session:
        session.expire_on_commit = False
        forecast_run = (
            session.query(ForecastRun)  # type: ignore
            .filter(ForecastRun.status == ForecastRunStatus.COMPLETED).filter(
                ForecastRun.run_type == EngineRunType.production).order_by(
                    desc(ForecastRun.id)).limit(1).first())

    logger.debug(
        f"Found last completed production forecast run: {forecast_run}")
    return cast(ForecastRun, forecast_run)
def assert_dsx_output_total_sum(dsx_write_database: Database,
                                expected_total_sum: int) -> None:
    """Assert count of values in dsx write database."""
    with dsx_write_database.transaction_context() as session:
        dsx_total_sum = sum(
            int(result.Value)
            for result in session.query(DsxOutput.c.Value).all())  # type: ignore

    assert dsx_total_sum == expected_total_sum, (
        f"Total sum of values ({dsx_total_sum}) in {DsxOutput.name} table of {dsx_write_database} "
        f"does not match the expectation ({expected_total_sum})")
    logger.info(
        f"Asserted total sum of values ({expected_total_sum}) in {DsxOutput.name} "
        f"of {dsx_write_database}")
def get_model_run_ids_for_forecast_run(internal_database: Database,
                                       forecast_run_id: int) -> List[int]:
    """Return successfully completed model run IDs associated with the forecast run."""
    with internal_database.transaction_context() as session:
        model_run_ids = [
            result.id for result in (
                session.query(ForecastModelRun)  # type: ignore
                .filter(ForecastModelRun.run_id == forecast_run_id).filter(
                    ForecastModelRun.status ==
                    ForecastModelRunStatus.COMPLETED).all())
        ]
    logger.debug(
        f"Found {len(model_run_ids)} completed model run(s) (id={model_run_ids}) "
        f"associated with forecast run ID: {forecast_run_id}")
    return model_run_ids
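These helpers pair naturally with assert_forecast_data_for_model_run above; a sketch of the resulting assertion loop (assumed, not from the original source):

run = get_last_successful_production_run(internal_database)
for model_run_id in get_model_run_ids_for_forecast_run(
        internal_database, run.id):
    assert_forecast_data_for_model_run(internal_database, model_run_id, run)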
Example #12
def ensure_schema_exists(database: Database) -> None:
    """Ensure database schema configured to be used with database exists and create it if missing."""
    if database.is_disabled():
        logger.error(
            f"Cannot setup schema, because {database} connection is not available"
        )
        return

    with database.transaction_context() as session:
        connection = session.connection()
        schema_name = database._database_schema
        existing_schemas = connection.dialect.get_schema_names(connection)

        if schema_name in existing_schemas:
            logger.info(f"Found schema: {schema_name} in {database}")
            return

        logger.info(f"Creating schema: {schema_name} in {database}")
        connection.execute(schema.CreateSchema(schema_name))
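A minimal sketch of calling ensure_schema_exists at startup, mirroring the setup mocked in Example #3 (this wiring is an assumption):

internal_database = Database(DatabaseType.internal)
ensure_schema_exists(internal_database)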