def test_save(ml_testing_dag_repository: MLTestingDagRepository,
              reset_db: None):
    """save() persists a new ml_testing_dag and rejects a duplicate.

    The first save for INSERTED_ML_DAG_ID / PARAMETER_3 must make the row
    visible in the table. A second identical save must raise DBException
    whose message mentions INSERTED_ML_DAG_ID.
    """
    table = ml_testing_dag_repository.table

    def fetch_saved_row():
        # First row matching the (ml_dag_id, parameter_3) pair under test.
        condition = and_(table.c.ml_dag_id == INSERTED_ML_DAG_ID,
                         table.c.parameter_3 == PARAMETER_3)
        return table.select().where(condition).execute().first()

    def build_row():
        # Fresh, id-less row so the repository performs an insert.
        return MLTestingDagRow(
            id=None,
            ml_dag=MLDagRow(id=INSERTED_ML_DAG_ID, parameter_1=PARAMETER_1),
            parameter_3=PARAMETER_3)

    # Precondition: nothing stored for this pair yet.
    assert fetch_saved_row() is None

    ml_testing_dag_repository.save(build_row())
    assert fetch_saved_row() is not None

    # Saving the same pair again must fail and name the offending ml_dag id.
    with pytest.raises(DBException) as excinfo:
        ml_testing_dag_repository.save(build_row())
    assert str(INSERTED_ML_DAG_ID) in str(excinfo.value)
Exemple #2
0
    def find_by_parameters(self, parameter_1: str,
                           parameter_3: str) -> MLTestingDagRow:
        """Return the MLTestingDagRow matching parameter_1 and parameter_3.

        The ml_testing_dag table is joined with MLDagRepository.table so the
        filter can combine parameter_1 (a column of the ml_dag table) with
        parameter_3 (a column of this table). Only the first matching row
        is used.

        Args:
            parameter_1: Value compared with MLDagRepository.table.c.parameter_1.
            parameter_3: Value compared with self.table.c.parameter_3.

        Returns:
            MLTestingDagRow whose ml_dag is populated from the joined row.

        Raises:
            DBException: If ml_testing_dag with parameters does not exist in db.
        """
        ml_dag_table = MLDagRepository.table
        row = self.table.join(ml_dag_table).select().where(
            and_(ml_dag_table.c.parameter_1 == parameter_1,
                 self.table.c.parameter_3 == parameter_3)).execute().first()

        # first() yields None when the query matched nothing.
        if row is None:
            raise DBException(
                f'ml_testing_dag with [parameter_1: {parameter_1}] and '
                f'[parameter_3: {parameter_3}] does not exist')

        return MLTestingDagRow(
            id=row[self.table.c.id],
            ml_dag=MLDagRow(id=row[ml_dag_table.c.id],
                            parameter_1=row[ml_dag_table.c.parameter_1]),
            parameter_3=row[self.table.c.parameter_3])
Exemple #3
0
def test_insert_dag(ml_dag_repository: MLDagRepository, reset_db: None):
    """save() inserts a new ml_dag and returns it equal to ML_DAG_1."""
    table = ml_dag_repository.table

    def lookup():
        # First row whose parameter_1 matches ML_DAG_1's.
        by_parameter_1 = table.c.parameter_1 == ML_DAG_1.parameter_1
        return table.select().where(by_parameter_1).execute().first()

    # Precondition: the row is not present before saving.
    assert lookup() is None

    saved = ml_dag_repository.save(
        MLDagRow(id=None, parameter_1=ML_DAG_1.parameter_1))
    assert saved == ML_DAG_1

    # The saved row is now queryable.
    assert lookup() is not None
Exemple #4
0
    def _get_ml_dag_id(engine: sqlalchemy.engine.Engine, **kwargs) -> int:
        """Return the ml_dag id for the run's parameter_1 / parameter_3.

        Reads both parameters from kwargs['dag_run'].conf (presumably an
        Airflow DagRun conf dict; confirm against the caller). If no
        ml_testing_dag exists for the pair, one is created. Its ml_dag is
        looked up by parameter_1 and is itself created when missing.

        Args:
            engine: Engine passed to every repository built here.
            **kwargs: Must contain 'dag_run' whose .conf holds
                'parameter_1' and 'parameter_3'.

        Returns:
            The id of the ml_dag attached to the found or created
            ml_testing_dag.
        """
        run_conf = kwargs['dag_run'].conf
        parameter_1 = run_conf['parameter_1']
        parameter_3 = run_conf['parameter_3']

        try:
            testing_dag = MLTestingDagRepository(
                engine=engine).find_by_parameters(parameter_1=parameter_1,
                                                  parameter_3=parameter_3)
        except DBException:
            # No ml_testing_dag yet: reuse the ml_dag for parameter_1 when
            # present, otherwise insert it, then insert the ml_testing_dag.
            try:
                parent_dag = MLDagRepository(engine=engine).find_by_parameter_1(
                    parameter_1=parameter_1)
            except DBException:
                parent_dag = MLDagRepository(engine=engine).save(
                    MLDagRow(id=None, parameter_1=parameter_1))

            testing_dag = MLTestingDagRepository(engine=engine).save(
                MLTestingDagRow(id=None, ml_dag=parent_dag,
                                parameter_3=parameter_3))

        return testing_dag.ml_dag.id
Exemple #5
0
import datetime

import pytest

from dags.exceptions.db_exception import DBException
from dags.repositories.ml_dag import MLDagRepository, MLDagRow
from dags.utils import db_utils

# sqlite database file name handed to db_utils.create_db_engine as `host`.
DB_NAME = 'test.db'
# Reference ml_dag row used by the tests.
ML_DAG_1 = MLDagRow(parameter_1='test_parameter_1', id=1)
# Lookup values presumably never inserted, for negative-path tests.
NON_EXISTENT_PARAMETER_1 = 'non_existent_parameter_1'
NON_EXISTENT_ML_DAG_ID = 100


# The repository is expected to hold no per-test state, so a single
# module-scoped instance is shared by all tests in this module.
@pytest.fixture(scope='module')
def ml_dag_repository() -> MLDagRepository:
    """ Fixture that builds an MLDagRepository backed by a local sqlite db """
    sqlite_engine = db_utils.create_db_engine(
        conn_type='sqlite',
        host=DB_NAME,
        login=None,
        password=None,
        schema=None,
    )
    repository = MLDagRepository(engine=sqlite_engine)
    return repository


@pytest.fixture()
def reset_db(ml_dag_repository: MLDagRepository) -> None:
    """ Resets DB before each test to initial testing state """
    ml_dag_repository.metadata.drop_all()