예제 #1
0
def test_get_ml_dag_id_insert_dag_run_and_training_dag_run(
        ml_training_dag_repository: MLTrainingDagRepository, reset_db: None):
    """Check that `_get_ml_dag_id` creates both the ML dag row and the ML training dag row.

    Neither row may exist before the call; both must be found afterwards.
    """

    def fetch_ml_dag():
        # First ML dag row matching INSERTED_ML_DAG_ID and PARAMETER_1, or None.
        query = MLDagRepository().table.select()
        criteria = and_(MLDagRepository.table.c.id == INSERTED_ML_DAG_ID,
                        MLDagRepository.table.c.parameter_1 == PARAMETER_1)
        return query.where(criteria).execute().first()

    def fetch_ml_training_dag():
        # First ML training dag row whose parameter_2 equals PARAMETER_2, or None.
        table = ml_training_dag_repository.table
        return table.select().where(
            table.c.parameter_2 == PARAMETER_2).execute().first()

    # Precondition: nothing has been inserted yet.
    assert fetch_ml_dag() is None
    assert fetch_ml_training_dag() is None

    run_conf = {'parameter_1': PARAMETER_1, 'parameter_2': PARAMETER_2}
    MLTrainingDag._get_ml_dag_id(
        engine=ml_training_dag_repository.metadata.bind,
        dag_run=Conf(conf=run_conf))

    # Postcondition: both rows now exist.
    assert fetch_ml_dag() is not None
    assert fetch_ml_training_dag() is not None
예제 #2
0
def test_find_by_parameter_1(ml_dag_repository: MLDagRepository, reset_db: None):
    """Check `find_by_parameter_1` for a stored value and for a value with no row."""
    # Seed a row carrying ML_DAG_1's parameter_1.
    seed_values = {
        'parameter_1': ML_DAG_1.parameter_1,
        'datetime_created': datetime.datetime.utcnow(),
    }
    ml_dag_repository.table.insert().values(**seed_values).execute()

    found = ml_dag_repository.find_by_parameter_1(parameter_1=ML_DAG_1.parameter_1)
    assert found.parameter_1 == ML_DAG_1.parameter_1

    # An unknown parameter_1 must raise DBException whose message contains the value.
    with pytest.raises(DBException) as raised:
        ml_dag_repository.find_by_parameter_1(parameter_1=NON_EXISTENT_PARAMETER_1)
    assert NON_EXISTENT_PARAMETER_1 in str(raised.value)
예제 #3
0
def test_find_by_id(ml_dag_repository: MLDagRepository, reset_db: None):
    """Check `find_by_id` for a stored id and for an id with no row."""
    # Seed a row with ML_DAG_1's id and parameter_1.
    seed_values = {
        'id': ML_DAG_1.id,
        'parameter_1': ML_DAG_1.parameter_1,
        'datetime_created': datetime.datetime.utcnow(),
    }
    ml_dag_repository.table.insert().values(**seed_values).execute()

    found = ml_dag_repository.find_by_id(id=ML_DAG_1.id)
    assert found.id == ML_DAG_1.id

    # An unknown id must raise DBException whose message contains the id.
    with pytest.raises(DBException) as raised:
        ml_dag_repository.find_by_id(id=NON_EXISTENT_ML_DAG_ID)
    assert str(NON_EXISTENT_ML_DAG_ID) in str(raised.value)
def reset_db(ml_testing_dag_repository: MLTestingDagRepository) -> None:
    """Bring the database back to the initial testing state.

    Calls `drop_all()` then `create_all()` on the repository's metadata and
    inserts one ML dag row (INSERTED_ML_DAG_ID, PARAMETER_1).
    """
    schema = ml_testing_dag_repository.metadata
    schema.drop_all()
    schema.create_all()

    seed_insert = MLDagRepository().table.insert().values(
        id=INSERTED_ML_DAG_ID, parameter_1=PARAMETER_1)
    seed_insert.execute()
예제 #5
0
    def finish_task(self, ml_dag_id: int) -> None:
        """Mark the task of `ml_dag_id` as finished.

        Runs the ML dag id check and the task check first, then sets
        `datetime_finished` to the current UTC time on every row of this
        table whose `ml_dag_id` matches.
        """
        MLDagRepository().check_ml_dag_id(ml_dag_id=ml_dag_id)
        self._check_task_with_ml_dag_id(ml_dag_id=ml_dag_id)

        belongs_to_dag = self.table.c.ml_dag_id == ml_dag_id
        finished_at = datetime.datetime.utcnow()
        self.table.update().where(belongs_to_dag).values(
            datetime_finished=finished_at).execute()
예제 #6
0
def ml_dag_repository() -> MLDagRepository:
    """Return an MLDagRepository backed by a sqlite engine (DB_NAME is passed as host)."""
    engine_settings = {
        'login': None,
        'password': None,
        'host': DB_NAME,
        'schema': None,
        'conn_type': 'sqlite',
    }
    sqlite_engine = db_utils.create_db_engine(**engine_settings)
    return MLDagRepository(engine=sqlite_engine)
예제 #7
0
def reset_db(common_task_1_repository: CommonTask1Repository) -> None:
    """Rebuild the schema, then seed one ML dag row and one task row referencing it."""
    metadata = common_task_1_repository.metadata
    metadata.drop_all()
    metadata.create_all()

    # Parent ML dag row.
    ml_dag_seed = MLDagRepository().table.insert().values(
        id=INSERTED_ML_DAG_ID, parameter_1='test_parameter_1')
    ml_dag_seed.execute()

    # Task row pointing at the ML dag inserted above.
    task_seed = common_task_1_repository.table.insert().values(
        ml_dag_id=INSERTED_ML_DAG_ID)
    task_seed.execute()
예제 #8
0
def reset_db(common_task_1_repository: CommonTask1Repository) -> None:
    """Restore the initial testing state of the database.

    Drops and recreates everything on the repository's metadata, then inserts
    an ML dag row and a task row linked to it through INSERTED_ML_DAG_ID.
    """
    schema = common_task_1_repository.metadata
    schema.drop_all()
    schema.create_all()

    ml_dag_values = {'id': INSERTED_ML_DAG_ID, 'parameter_1': 'test_parameter_1'}
    MLDagRepository().table.insert().values(**ml_dag_values).execute()

    task_values = {'ml_dag_id': INSERTED_ML_DAG_ID}
    common_task_1_repository.table.insert().values(**task_values).execute()
예제 #9
0
    def is_task_finished(self, ml_dag_id: int) -> bool:
        """Tell whether the task of `ml_dag_id` is finished.

        Runs the ML dag id check and the task check first, then reads the
        first row of this table whose `ml_dag_id` matches. The task counts as
        finished when that row's `datetime_finished` is truthy.
        """
        MLDagRepository().check_ml_dag_id(ml_dag_id=ml_dag_id)
        self._check_task_with_ml_dag_id(ml_dag_id=ml_dag_id)

        query = self.table.select().where(self.table.c.ml_dag_id == ml_dag_id)
        task_row = query.execute().first()

        return bool(task_row.datetime_finished)
예제 #10
0
def test_insert_dag(ml_dag_repository: MLDagRepository, reset_db: None):
    """Check that `save` stores a new ML dag row and returns a value equal to ML_DAG_1."""

    def fetch_by_parameter_1():
        # First row whose parameter_1 equals ML_DAG_1.parameter_1, or None.
        table = ml_dag_repository.table
        return table.select().where(
            table.c.parameter_1 == ML_DAG_1.parameter_1).execute().first()

    # The row must not exist before saving.
    assert fetch_by_parameter_1() is None

    new_row = MLDagRow(id=None, parameter_1=ML_DAG_1.parameter_1)
    saved = ml_dag_repository.save(new_row)
    assert saved == ML_DAG_1

    # After saving, the row is retrievable.
    assert fetch_by_parameter_1() is not None
예제 #11
0
    def _get_ml_dag_id(engine: sqlalchemy.engine.Engine, **kwargs) -> int:
        """Return the ML dag id for the `parameter_1` / `parameter_3` of the current dag run.

        Both parameters are read from `kwargs['dag_run'].conf`. If an ML testing
        dag is found for them, its ML dag id is returned. Otherwise the ML dag
        for `parameter_1` is looked up (and saved if missing), a new ML testing
        dag is saved for it, and that ML dag's id is returned.
        """
        run_conf = kwargs['dag_run'].conf
        parameter_1 = run_conf['parameter_1']
        parameter_3 = run_conf['parameter_3']

        try:
            testing_dag = MLTestingDagRepository(engine=engine).find_by_parameters(
                parameter_1=parameter_1, parameter_3=parameter_3)
        except DBException:
            # No ML testing dag yet: reuse the ML dag of parameter_1, or create it.
            try:
                parent_dag = MLDagRepository(engine=engine).find_by_parameter_1(
                    parameter_1=parameter_1)
            except DBException:
                new_dag_row = MLDagRow(id=None, parameter_1=parameter_1)
                parent_dag = MLDagRepository(engine=engine).save(new_dag_row)

            new_testing_row = MLTestingDagRow(id=None,
                                              ml_dag=parent_dag,
                                              parameter_3=parameter_3)
            testing_dag = MLTestingDagRepository(engine=engine).save(new_testing_row)

        return testing_dag.ml_dag.id
예제 #12
0
def test_get_ml_dag_id(ml_training_dag_repository: MLTrainingDagRepository,
                       reset_db: None):
    """Check that `_get_ml_dag_id` returns the id of the ML dag when both rows already exist."""
    # Seed the parent ML dag row.
    MLDagRepository().table.insert().values(
        id=INSERTED_ML_DAG_ID, parameter_1=PARAMETER_1).execute()

    # Seed the ML training dag row that references it.
    training_values = {
        'id': INSERTED_ML_TESTING_DAG_ID,
        'ml_dag_id': INSERTED_ML_DAG_ID,
        'parameter_2': PARAMETER_2,
        'datetime_created': datetime.datetime.utcnow(),
    }
    ml_training_dag_repository.table.insert().values(**training_values).execute()

    dag_run = Conf(conf={'parameter_1': PARAMETER_1, 'parameter_2': PARAMETER_2})
    resolved_id = MLTrainingDag._get_ml_dag_id(
        engine=ml_training_dag_repository.metadata.bind, dag_run=dag_run)

    assert resolved_id == INSERTED_ML_DAG_ID
예제 #13
0
def test_get_ml_dag_id_insert_training_dag_run(
        ml_training_dag_repository: MLTrainingDagRepository, reset_db: None):
    """Check that `_get_ml_dag_id` creates the ML training dag row for an existing ML dag row."""

    def fetch_ml_training_dag():
        # First ML training dag row whose parameter_2 equals PARAMETER_2, or None.
        table = ml_training_dag_repository.table
        return table.select().where(
            table.c.parameter_2 == PARAMETER_2).execute().first()

    # The ML dag row exists up front; the training row does not.
    seed_insert = MLDagRepository().table.insert().values(
        id=INSERTED_ML_DAG_ID, parameter_1=PARAMETER_1)
    seed_insert.execute()
    assert fetch_ml_training_dag() is None

    run_conf = {'parameter_1': PARAMETER_1, 'parameter_2': PARAMETER_2}
    MLTrainingDag._get_ml_dag_id(
        engine=ml_training_dag_repository.metadata.bind,
        dag_run=Conf(conf=run_conf))

    # The call must have inserted the training row.
    assert fetch_ml_training_dag() is not None
예제 #14
0
def test_check_ml_dag_id(ml_dag_repository: MLDagRepository, reset_db: None):
    """Check that `check_ml_dag_id` raises DBException for an id with no ML dag row.

    The exception message must contain the offending id.
    """
    with pytest.raises(DBException) as e_info:
        # No `assert` around the call: only the raised exception is under test.
        # Asserting on the return value would turn a non-raising, falsy-returning
        # call into an AssertionError instead of pytest's "DID NOT RAISE" failure.
        ml_dag_repository.check_ml_dag_id(ml_dag_id=NON_EXISTENT_ML_DAG_ID)
    assert str(NON_EXISTENT_ML_DAG_ID) in str(e_info.value)