def test_get_ml_dag_id_insert_dag_run_and_training_dag_run(
        ml_training_dag_repository: MLTrainingDagRepository, reset_db: None):
    """
    _get_ml_dag_id should create both the MLDagTable row and the
    MLTrainingDagTable row when neither exists yet for the given parameters.
    """

    def fetch_ml_dag():
        # MLDagTable row with the expected id and parameter_1, or None.
        return MLDagRepository().table.select().where(
            and_(MLDagRepository.table.c.id == INSERTED_ML_DAG_ID,
                 MLDagRepository.table.c.parameter_1 == PARAMETER_1)
        ).execute().first()

    def fetch_ml_training_dag():
        # MLTrainingDagTable row with the expected parameter_2, or None.
        return ml_training_dag_repository.table.select().where(
            ml_training_dag_repository.table.c.parameter_2 == PARAMETER_2
        ).execute().first()

    # Precondition: neither row is present yet.
    assert fetch_ml_dag() is None
    assert fetch_ml_training_dag() is None

    MLTrainingDag._get_ml_dag_id(
        engine=ml_training_dag_repository.metadata.bind,
        dag_run=Conf(conf={
            'parameter_1': PARAMETER_1,
            'parameter_2': PARAMETER_2,
        }),
    )

    # Postcondition: both rows have been created.
    assert fetch_ml_dag() is not None
    assert fetch_ml_training_dag() is not None
def test_find_by_parameter_1(ml_dag_repository: MLDagRepository, reset_db: None):
    """
    find_by_parameter_1 returns the stored row for a known parameter_1 and
    raises DBException (mentioning the value) for an unknown one.
    """
    ml_dag_repository.table.insert().values(
        parameter_1=ML_DAG_1.parameter_1,
        datetime_created=datetime.datetime.utcnow(),
    ).execute()

    found = ml_dag_repository.find_by_parameter_1(parameter_1=ML_DAG_1.parameter_1)
    assert found.parameter_1 == ML_DAG_1.parameter_1

    # An unknown parameter_1 must raise, and the error text should name it.
    with pytest.raises(DBException) as excinfo:
        ml_dag_repository.find_by_parameter_1(parameter_1=NON_EXISTENT_PARAMETER_1)
    assert NON_EXISTENT_PARAMETER_1 in str(excinfo.value)
def test_find_by_id(ml_dag_repository: MLDagRepository, reset_db: None):
    """
    find_by_id returns the stored row for a known id and raises DBException
    (mentioning the id) for an unknown one.
    """
    ml_dag_repository.table.insert().values(
        id=ML_DAG_1.id,
        parameter_1=ML_DAG_1.parameter_1,
        datetime_created=datetime.datetime.utcnow(),
    ).execute()

    found = ml_dag_repository.find_by_id(id=ML_DAG_1.id)
    assert found.id == ML_DAG_1.id

    # An unknown id must raise, and the error text should name it.
    with pytest.raises(DBException) as excinfo:
        ml_dag_repository.find_by_id(id=NON_EXISTENT_ML_DAG_ID)
    assert str(NON_EXISTENT_ML_DAG_ID) in str(excinfo.value)
def reset_db(ml_testing_dag_repository: MLTestingDagRepository) -> None:
    """
    Restore the initial testing state before each test: drop and recreate the
    tables registered on the repository's metadata, then seed one MLDagTable
    row (id=INSERTED_ML_DAG_ID, parameter_1=PARAMETER_1).
    """
    metadata = ml_testing_dag_repository.metadata
    metadata.drop_all()
    metadata.create_all()

    seed_ml_dag = MLDagRepository().table.insert().values(
        id=INSERTED_ML_DAG_ID, parameter_1=PARAMETER_1)
    seed_ml_dag.execute()
def finish_task(self, ml_dag_id: int) -> None:
    """
    Mark the task belonging to ``ml_dag_id`` as finished by setting its
    datetime_finished column to the current UTC time.

    The ml_dag id is validated with MLDagRepository().check_ml_dag_id and the
    task row with self._check_task_with_ml_dag_id before anything is written.
    """
    MLDagRepository().check_ml_dag_id(ml_dag_id=ml_dag_id)
    self._check_task_with_ml_dag_id(ml_dag_id=ml_dag_id)

    task_table = self.table
    update_stmt = task_table.update().where(task_table.c.ml_dag_id == ml_dag_id)
    update_stmt.values(datetime_finished=datetime.datetime.utcnow()).execute()
def ml_dag_repository() -> MLDagRepository:
    """
    Provide an MLDagRepository bound to the local sqlite database named by
    DB_NAME.
    """
    sqlite_engine = db_utils.create_db_engine(
        login=None,
        password=None,
        host=DB_NAME,
        schema=None,
        conn_type='sqlite',
    )
    return MLDagRepository(engine=sqlite_engine)
def reset_db(common_task_1_repository: CommonTask1Repository) -> None:
    """
    Reset the DB to its initial testing state: drop and recreate the tables
    registered on the repository's metadata, insert an MLDagTable row with
    id INSERTED_ML_DAG_ID, and insert a CommonTask1 row whose ml_dag_id is
    INSERTED_ML_DAG_ID.
    """
    metadata = common_task_1_repository.metadata
    metadata.drop_all()
    metadata.create_all()

    # Parent ml_dag row first, then the task row pointing at it.
    MLDagRepository().table.insert().values(
        id=INSERTED_ML_DAG_ID, parameter_1='test_parameter_1').execute()
    task_insert = common_task_1_repository.table.insert()
    task_insert.values(ml_dag_id=INSERTED_ML_DAG_ID).execute()
def reset_db(common_task_1_repository: CommonTask1Repository) -> None:
    """
    Put the DB back into its initial testing state before each test.

    Tables on the repository's metadata are dropped and recreated, then one
    MLDagTable row (id=INSERTED_ML_DAG_ID) and one CommonTask1 row with
    ml_dag_id=INSERTED_ML_DAG_ID are inserted.
    """
    schema = common_task_1_repository.metadata
    schema.drop_all()
    schema.create_all()

    ml_dag_seed = MLDagRepository().table.insert().values(
        id=INSERTED_ML_DAG_ID, parameter_1='test_parameter_1')
    ml_dag_seed.execute()

    task_seed = common_task_1_repository.table.insert().values(
        ml_dag_id=INSERTED_ML_DAG_ID)
    task_seed.execute()
def is_task_finished(self, ml_dag_id: int) -> bool:
    """
    Tell whether the task belonging to ``ml_dag_id`` has finished, i.e.
    whether its datetime_finished column holds a truthy value.

    The ml_dag id and the task row are validated first via
    MLDagRepository().check_ml_dag_id and self._check_task_with_ml_dag_id.
    """
    MLDagRepository().check_ml_dag_id(ml_dag_id=ml_dag_id)
    self._check_task_with_ml_dag_id(ml_dag_id=ml_dag_id)

    task_row = self.table.select().where(
        self.table.c.ml_dag_id == ml_dag_id).execute().first()
    return bool(task_row.datetime_finished)
def test_insert_dag(ml_dag_repository: MLDagRepository, reset_db: None):
    """
    save() stores a new MLDagTable row and returns a tuple equal to ML_DAG_1.
    """

    def lookup_by_parameter_1():
        # Row whose parameter_1 matches ML_DAG_1, or None.
        return ml_dag_repository.table.select().where(
            ml_dag_repository.table.c.parameter_1 == ML_DAG_1.parameter_1
        ).execute().first()

    assert lookup_by_parameter_1() is None

    saved = ml_dag_repository.save(
        MLDagRow(id=None, parameter_1=ML_DAG_1.parameter_1))
    assert saved == ML_DAG_1

    assert lookup_by_parameter_1() is not None
def _get_ml_dag_id(engine: sqlalchemy.engine.Engine, **kwargs) -> int:
    """
    Resolve the ml_dag id for the run's (parameter_1, parameter_3) pair.

    Reads parameter_1 and parameter_3 from ``kwargs['dag_run'].conf``. If an
    ml_testing_dag already exists for that pair it is reused. Otherwise the
    ml_dag for parameter_1 is looked up (and saved if missing), and a new
    ml_testing_dag linked to it is saved.

    :param engine: engine passed to every repository that is constructed.
    :return: id of the ml_dag attached to the found/created ml_testing_dag.
    """
    run_conf = kwargs['dag_run'].conf
    parameter_1, parameter_3 = run_conf['parameter_1'], run_conf['parameter_3']

    try:
        testing_dag = MLTestingDagRepository(engine=engine).find_by_parameters(
            parameter_1=parameter_1, parameter_3=parameter_3)
    except DBException:
        # No testing run for this pair yet: resolve (or create) its parent ml_dag.
        try:
            parent_dag = MLDagRepository(engine=engine).find_by_parameter_1(
                parameter_1=parameter_1)
        except DBException:
            parent_dag = MLDagRepository(engine=engine).save(
                MLDagRow(id=None, parameter_1=parameter_1))
        testing_dag = MLTestingDagRepository(engine=engine).save(
            MLTestingDagRow(id=None, ml_dag=parent_dag, parameter_3=parameter_3))

    return testing_dag.ml_dag.id
def test_get_ml_dag_id(ml_training_dag_repository: MLTrainingDagRepository, reset_db: None):
    """
    _get_ml_dag_id returns INSERTED_ML_DAG_ID when the matching ml_dag and
    ml_training_dag rows already exist.
    """
    # Seed the ml_dag row and the ml_training_dag row _get_ml_dag_id should find.
    MLDagRepository().table.insert().values(
        id=INSERTED_ML_DAG_ID, parameter_1=PARAMETER_1).execute()
    # NOTE(review): a *testing* id constant is used for a *training* dag row —
    # confirm INSERTED_ML_TESTING_DAG_ID is the intended constant here.
    ml_training_dag_repository.table.insert().values(
        id=INSERTED_ML_TESTING_DAG_ID,
        ml_dag_id=INSERTED_ML_DAG_ID,
        parameter_2=PARAMETER_2,
        datetime_created=datetime.datetime.utcnow(),
    ).execute()

    returned_id = MLTrainingDag._get_ml_dag_id(
        engine=ml_training_dag_repository.metadata.bind,
        dag_run=Conf(conf={
            'parameter_1': PARAMETER_1,
            'parameter_2': PARAMETER_2,
        }),
    )
    assert returned_id == INSERTED_ML_DAG_ID
def test_get_ml_dag_id_insert_training_dag_run(
        ml_training_dag_repository: MLTrainingDagRepository, reset_db: None):
    """
    With an existing MLDagTable row, _get_ml_dag_id should create the missing
    MLTrainingDagTable row.
    """
    MLDagRepository().table.insert().values(
        id=INSERTED_ML_DAG_ID, parameter_1=PARAMETER_1).execute()

    def fetch_ml_training_dag():
        # MLTrainingDagTable row with the expected parameter_2, or None.
        return ml_training_dag_repository.table.select().where(
            ml_training_dag_repository.table.c.parameter_2 == PARAMETER_2
        ).execute().first()

    assert fetch_ml_training_dag() is None

    MLTrainingDag._get_ml_dag_id(
        engine=ml_training_dag_repository.metadata.bind,
        dag_run=Conf(conf={
            'parameter_1': PARAMETER_1,
            'parameter_2': PARAMETER_2,
        }),
    )

    assert fetch_ml_training_dag() is not None
def test_check_ml_dag_id(ml_dag_repository: MLDagRepository, reset_db: None):
    """
    check_ml_dag_id must raise DBException, mentioning the id, for an id that
    does not exist.
    """
    # Call directly instead of wrapping it in `assert`: the call is expected to
    # raise. An `assert` would turn a non-raising falsy return into an unrelated
    # AssertionError instead of pytest's clear "DID NOT RAISE" failure.
    with pytest.raises(DBException) as e_info:
        ml_dag_repository.check_ml_dag_id(ml_dag_id=NON_EXISTENT_ML_DAG_ID)
    # Checked after the `with` block; inside it, this line would be unreachable.
    assert str(NON_EXISTENT_ML_DAG_ID) in str(e_info.value)