def prepare_database() -> Iterator[None]:
    """Prepare the internal database with test forecast runs and clean up afterwards."""
    internal_database = Database(DatabaseType.internal)
    with internal_database.transaction_context() as session:
        model_run_with_one_result_row = ForecastModelRun(
            model_name="model_run_with_one_forecast_data_row",
            start=FIXED_TIMESTAMP,
            end=FIXED_TIMESTAMP,
            status=ForecastModelRunStatus.COMPLETED,
        )
        model_run_with_zero_result_rows = ForecastModelRun(
            model_name="model_run_with_zero_forecast_data_rows",
            start=FIXED_TIMESTAMP,
            end=FIXED_TIMESTAMP,
            status=ForecastModelRunStatus.COMPLETED,
        )
        forecast_run = ForecastRun(
            start=FIXED_TIMESTAMP,
            end=FIXED_TIMESTAMP,
            run_type=EngineRunType.production,
            includes_cleaning=True,
            status=ForecastRunStatus.COMPLETED,
            forecast_periods=1,
            prediction_start_month=FIXED_TIMESTAMP,
            model_runs=[
                model_run_with_one_result_row, model_run_with_zero_result_rows
            ],  # type: ignore
        )
        session.add(forecast_run)
        session.flush()
        test_model_run_id = model_run_with_one_result_row.id
        forecast_data = {
            "model_run_id": test_model_run_id,
            "Contract_ID": "test_contract_for_compare_structure_database",
            "Item_ID": -1,
            "Prediction_Start_Month": 20200101,
            "Predicted_Month": 20200101,
            "Prediction_Months_Delta": 0,
            "Prediction_Raw": 0,
            "Prediction_Post": 0,
            "Actual": None,
            "Accuracy": None,
        }
        session.execute(ForecastData.insert().values(forecast_data))

    yield

    # Cleanup: deleting the forecast run cascades to its forecast_model_runs
    # via the FK relationship; the forecast_data row is deleted explicitly.
    with internal_database.transaction_context() as session:
        forecast_run = session.merge(forecast_run)
        session.delete(forecast_run)  # type: ignore
        assert (delete_test_data(
            Query(ForecastData).filter(
                ForecastData.c.model_run_id == test_model_run_id)  # type: ignore
        ) == 1), "Cleanup failed, check database for uncleaned data"
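
# A minimal usage sketch for the generator above, registering it as a pytest
# fixture (assumes ``import pytest`` at module top). The fixture name
# ``prepared_database`` and the test below are illustrative assumptions, not
# part of the original module.
prepared_database = pytest.fixture(prepare_database)


def test_prepared_rows(prepared_database: None) -> None:
    # With the fixture active, exactly one forecast_data row exists for the
    # seeded contract; the fixture's cleanup removes it again afterwards.
    internal_database = Database(DatabaseType.internal)
    with internal_database.transaction_context() as session:
        rows = session.execute(ForecastData.select().where(
            ForecastData.c.Contract_ID ==
            "test_contract_for_compare_structure_database")).fetchall()
    assert len(rows) == 1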
def update_forecast_data_with_cleaned_data_sales(
        internal_database: Database, cleaned_data_run_id: Optional[int],
        cleaned_data_newest_month: Optional[datetime]) -> None:
    """Update actual values of all previous forecasts with the newest cleaned data.

    Forecasts are updated up to the newest date in the cleaned data. NULL
    actual values are set to zero if their date is on or before the newest
    date in the cleaned data; any dates after it are left unchanged.

    Args:
        internal_database: Service to access internal database.
        cleaned_data_run_id: ID of the latest successful platform run which
            included cleaning.
        cleaned_data_newest_month: Newest month in the cleaned data table
            associated with ``cleaned_data_run_id``.
    """
    if internal_database.is_disabled():
        logger.warning(
            "Skipping update of previous forecasts due to disabled database")
        return

    assert cleaned_data_run_id is not None, \
        "Invalid program state: Expected cleaned_data_run_id to exist"
    with internal_database.transaction_context() as session:
        updated_existing = session.execute(ForecastData.update().where(
            CleanedData.c.run_id == cleaned_data_run_id).where(
                ForecastData.c.Contract_ID == CleanedData.c.Contract_ID).where(
                    ForecastData.c.Item_ID == CleanedData.c.Item_ID).where(
                        ForecastData.c.Predicted_Month ==
                        CleanedData.c.Date_YYYYMM).values({
                            "Actual": CleanedData.c.Order_Quantity,
                            "Accuracy": compute_accuracy_as_sql(
                                CleanedData.c.Order_Quantity,
                                ForecastData.c.Prediction_Post),
                        }))
        logger.info(
            f"Updated {updated_existing.rowcount} rows of forecast_data with old actual values to"
            f" newest actual values from cleaned_data")

    assert cleaned_data_newest_month, \
        "Invalid program state: Expected cleaned_data_newest_month to exist"
    newest_month = int(
        cleaned_data_newest_month.strftime(PREDICTION_MONTH_FORMAT))
    with internal_database.transaction_context() as session:
        updated_nulls = session.execute(
            ForecastData.update().where(
                ForecastData.c.Predicted_Month <= newest_month).where(
                    ForecastData.c.Actual == None)  # noqa: E711
            .values({
                "Actual": 0,
                "Accuracy": compute_accuracy_as_sql(
                    0, ForecastData.c.Prediction_Post),
            }))
        logger.info(
            f"Updated {updated_nulls.rowcount} rows of forecast_data without actual values"
            f" to zero, up to newest cleaned_data month")
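
# A hedged orchestration sketch: how the update above is typically driven from
# the run ID and newest month returned by the helpers further down in this
# module. The wrapper name ``refresh_forecast_actuals`` is an illustrative
# assumption, not part of the original module.
def refresh_forecast_actuals(internal_database: Database) -> None:
    # Resolve the latest cleaning run and the newest month it covers, then
    # backfill actuals and accuracies for all stored forecasts.
    run_id = get_last_successful_cleaning_run_id(internal_database)
    newest_month = get_newest_cleaned_data_month(internal_database, run_id)
    update_forecast_data_with_cleaned_data_sales(internal_database, run_id,
                                                 newest_month)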
def test_ensure_schema_exists_creates_missing_schema(
        monkeypatch: MonkeyPatch, caplog: LogCaptureFixture) -> None:
    new_schema = "new_schema"
    mock_get_schema_names = Mock(return_value=["old_schema"])
    internal_database = Database(DatabaseType.internal)
    internal_database._database_schema = new_schema

    # Capture the unpatched Connection.execute so that statements other than
    # CreateSchema are still executed for real.
    original_execute = Connection.execute

    def mock_execute_create_schema(
            connection: Connection,
            statement: Executable) -> Optional[Connection]:
        if isinstance(statement, CreateSchema):
            assert statement.element == new_schema
            return None
        return cast(Connection, original_execute(connection, statement))

    monkeypatch.setattr(MSDialect_pyodbc, "get_schema_names",
                        mock_get_schema_names)
    monkeypatch.setattr(Connection, "execute", mock_execute_create_schema)

    with caplog.at_level(logging.INFO):
        ensure_schema_exists(internal_database)

    assert f"Creating schema: {new_schema} in internal database" in caplog.messages
    mock_get_schema_names.assert_called_once()
def get_last_successful_cleaning_run_id(internal_database: Database) -> int:
    """Return the ID of the last successfully completed cleaning run.

    Args:
        internal_database: Service to access internal database.
    """
    with internal_database.transaction_context() as session:
        cleaning_run_id = (
            session.query(func.max(ForecastRun.id))  # type: ignore
            .filter(ForecastRun.status == ForecastRunStatus.COMPLETED)
            .filter(ForecastRun.includes_cleaning == True)  # noqa: E712
            .scalar())
        logger.debug(
            f"Found highest completed cleaning run ID: {cleaning_run_id}")
        cleaned_data_ids = session.query(distinct(
            CleanedData.c.run_id)).all()  # type: ignore
        logger.debug(
            f"Found cleaned data IDs in internal database: {cleaned_data_ids}")
        if not cleaning_run_id or len(cleaned_data_ids) != 1 \
                or cleaned_data_ids[0][0] != cleaning_run_id:
            raise DataException(
                "Cannot determine valid cleaning data, "
                "please re-run cleaning steps, "
                'e.g. by providing the "--force-reload" parameter')
        return int(cleaning_run_id)
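
# Usage sketch (assumed): callers typically guard against DataException and
# fall back to forcing a reload of the cleaning steps. The function name
# ``example_resolve_cleaning_run`` is an illustrative assumption.
def example_resolve_cleaning_run(
        internal_database: Database) -> Optional[int]:
    try:
        return get_last_successful_cleaning_run_id(internal_database)
    except DataException:
        logger.warning(
            "No valid cleaning data found; re-run with --force-reload")
        return None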
def assert_forecast_data_for_model_run(internal_database: Database,
                                       model_run_id: int,
                                       run: ForecastRun) -> None:
    """Assert count of rows in forecast data table of internal database against expected structure."""
    with internal_database.transaction_context() as session:
        model_name = (
            session.query(ForecastModelRun)  # type: ignore
            .filter(ForecastModelRun.id == model_run_id).first().model_name)
        run_parameters = ExpectedForecastStructureParameters(
            account_name=model_name,
            forecast_periods=run.forecast_periods,
            prediction_month=run.prediction_start_month,
        )
        expected_forecast_structure = get_expected_forecast_structure(
            run_parameters)
        expected_number_of_rows = expected_forecast_structure.shape[0]
        forecast_data_count = get_forecast_data_count_for_model_run(
            internal_database, model_run_id)
        assert forecast_data_count == expected_number_of_rows, (
            f"Forecast data count ({forecast_data_count}) for {model_name} (model_run={model_run_id}) "
            f"does not match the expectation ({expected_number_of_rows}) defined by {run_parameters}"
        )
        logger.info(
            f"Asserted forecast data count ({forecast_data_count}) for {model_name} (model_run={model_run_id})"
        )
def assert_dsx_output_count(dsx_write_database: Database,
                            expected_count: int) -> None:
    """Assert count of values saved in dsx write database."""
    with dsx_write_database.transaction_context() as session:
        dsx_data_count = session.query(DsxOutput).count()  # type: ignore
        assert dsx_data_count == expected_count, (
            f"Total count of values ({dsx_data_count}) in {DsxOutput.name} table of {dsx_write_database} "
            f"does not match the expectation ({expected_count})")
        logger.info(
            f"Asserted total count of values ({expected_count}) in {DsxOutput.name} "
            f"of {dsx_write_database}")
def get_forecast_data_count_for_model_run(internal_database: Database,
                                          model_run_id: int) -> int:
    """Return count of forecast_data entries associated with model run."""
    with internal_database.transaction_context() as session:
        forecast_data_count = (
            session.query(ForecastData)  # type: ignore
            .filter(ForecastData.c.model_run_id == model_run_id).count())
        logger.debug(
            f"Found {forecast_data_count} forecast data rows associated with forecast model run: {model_run_id}"
        )
        return cast(int, forecast_data_count)
def get_newest_cleaned_data_month(internal_database: Database,
                                  cleaned_data_run_id: int) -> datetime:
    """Return the most recent month from the cleaned data table of the internal database.

    Args:
        internal_database: Service to access internal database.
        cleaned_data_run_id: ID of the latest successful platform run which
            included cleaning.
    """
    with internal_database.transaction_context() as session:
        return cast(
            datetime,
            session.query(func.max(CleanedData.c.Date))  # type: ignore
            .filter(CleanedData.c.run_id == cleaned_data_run_id).scalar(),
        )
def get_last_successful_production_run(
        internal_database: Database) -> ForecastRun:
    """Return last successful production run as detached object."""
    with internal_database.transaction_context() as session:
        # Keep attributes loaded after the transaction commits, so the
        # returned object remains usable in its detached state.
        session.expire_on_commit = False
        forecast_run = (
            session.query(ForecastRun)  # type: ignore
            .filter(ForecastRun.status == ForecastRunStatus.COMPLETED)
            .filter(ForecastRun.run_type == EngineRunType.production)
            .order_by(desc(ForecastRun.id)).limit(1).first())
        logger.debug(
            f"Found last completed production forecast run: {forecast_run}")
        return cast(ForecastRun, forecast_run)
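
# A hedged end-to-end sketch: validate forecast data counts for every
# completed model run of the last production run, combining the helpers in
# this module. The wrapper name ``assert_last_production_run_structure`` is an
# illustrative assumption, not part of the original module.
def assert_last_production_run_structure(
        internal_database: Database) -> None:
    run = get_last_successful_production_run(internal_database)
    for model_run_id in get_model_run_ids_for_forecast_run(
            internal_database, run.id):
        assert_forecast_data_for_model_run(internal_database, model_run_id,
                                           run)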
def assert_dsx_output_total_sum(dsx_write_database: Database,
                                expected_total_sum: int) -> None:
    """Assert total sum of values saved in dsx write database."""
    with dsx_write_database.transaction_context() as session:
        dsx_total_sum = sum(
            int(result.Value)
            for result in session.query(DsxOutput.c.Value).all())  # type: ignore
        assert dsx_total_sum == expected_total_sum, (
            f"Total sum of values ({dsx_total_sum}) in {DsxOutput.name} table of {dsx_write_database} "
            f"does not match the expectation ({expected_total_sum})")
        logger.info(
            f"Asserted total sum of values ({expected_total_sum}) in {DsxOutput.name} "
            f"of {dsx_write_database}")
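
# Usage sketch (assumed): the count and total-sum assertions are typically run
# together after a DSX export. The expected numbers here are placeholders, and
# ``example_check_dsx_output`` is an illustrative name.
def example_check_dsx_output(dsx_write_database: Database) -> None:
    assert_dsx_output_count(dsx_write_database, expected_count=42)
    assert_dsx_output_total_sum(dsx_write_database, expected_total_sum=1000)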
def get_model_run_ids_for_forecast_run(internal_database: Database,
                                       forecast_run_id: int) -> List[int]:
    """Return successfully completed model run IDs associated with the forecast run."""
    with internal_database.transaction_context() as session:
        model_run_ids = [
            result.id for result in (
                session.query(ForecastModelRun)  # type: ignore
                .filter(ForecastModelRun.run_id == forecast_run_id).filter(
                    ForecastModelRun.status ==
                    ForecastModelRunStatus.COMPLETED).all())
        ]
        logger.debug(
            f"Found {len(model_run_ids)} completed model run(s) (id={model_run_ids}) "
            f"associated with forecast run ID: {forecast_run_id}")
        return model_run_ids
def ensure_schema_exists(database: Database) -> None:
    """Ensure the schema configured for the database exists, creating it if missing."""
    if database.is_disabled():
        logger.error(
            f"Cannot setup schema, because {database} connection is not available"
        )
        return

    with database.transaction_context() as session:
        connection = session.connection()
        schema_name = database._database_schema
        existing_schemas = connection.dialect.get_schema_names(connection)
        if schema_name in existing_schemas:
            logger.info(f"Found schema: {schema_name} in {database}")
            return
        logger.info(f"Creating schema: {schema_name} in {database}")
        connection.execute(schema.CreateSchema(schema_name))
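
# Usage sketch (assumed): ensuring the schema exists is an idempotent startup
# step, safe to call before any table creation or data access.
def example_startup(internal_database: Database) -> None:
    ensure_schema_exists(internal_database)
    # Subsequent DDL/DML can now rely on the configured schema being present;
    # calling it again is a no-op apart from the "Found schema" log line.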