Example #1
def create_initial_application_acl(mapper, connection, target):
    if target.application_type == SVN:
        acl_rules = [
                    ('internal_developer', 'edit'),
                    ('internal_developer', 'view'),
                    ('external_developer', 'edit'),
                    ('external_developer', 'view'),
                    ]
    else:
        acl_rules = [
                    ('internal_developer', 'view'),
                    ('external_developer', 'view'),
                    ('secretary', 'view'),
                    ('secretary', 'edit'),
                    ]

    if target.application_type == 'trac':
        acl_rules.append(('customer', 'view'))

    for role_id, permission_name in acl_rules:
        acl = Session.object_session(target).query(ApplicationACL).get((target.id, role_id, permission_name))
        if not acl:
            acl = ApplicationACL(application_id=target.id,
                                 role_id=role_id,
                                 permission_name=permission_name)
            Session.object_session(target).add(acl)
        else:
            # XXX this should not happen.
            pass
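
The (mapper, connection, target) signature above matches SQLAlchemy's mapper-level event hooks; a hedged sketch of how such a listener is usually wired up, assuming a mapped Application class and an 'after_insert' event (both assumptions, not shown in the source):

from sqlalchemy import event

# Hypothetical registration: run the ACL initializer right after an
# Application row is inserted.
event.listen(Application, 'after_insert', create_initial_application_acl)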
Example #2
def list_videos(session: Session, count):
    if int(count) > 0:
        qry = reversed(session.query(Video).order_by(Video.id.desc()).limit(int(count)).all())
    else:
        qry = session.query(Video).order_by(Video.id.asc())

    for vid in qry:
        print("Video %s:" % vid.id)
        print("\tTitle: %s" % vid.title)
        print("\tDate Added: %s" % vid.date_added)
        if vid.description:
            print("\tDescription:\n\t\t%s" % vid.description.replace("\n", "\n\t\t"))
        else:
            print("\tDescription: None")
        print("\tKeywords: %s" % vid.keywords)
        print("\tGame: %s" % vid.game)
        print("\tStart Segment: %s" % vid.start_segment)
        print("\tEnd Segment: %s" % vid.end_segment)
        print("\tProcessing:")
        print("\t\tDone: %s" % vid.processing_done)
        print("\t\tStatus: %s" % str(vid.processing_status).replace("\n", "\n\t\t\t"))
        print("\tTwitch:")
        print("\t\tDo: %s" % vid.do_twitch)
        print("\t\tDone: %s" % vid.done_twitch)
        print("\t\tStatus: %s" % str(vid.twitch_status).replace("\n", "\n\t\t\t"))
        print("\tYouTube:")
        print("\t\tDo: %s" % vid.do_youtube)
        print("\t\tDone: %s" % vid.done_youtube)
        print("\t\tStatus: %s" % str(vid.youtube_status).replace("\n", "\n\t\t\t"))
        print("\t\tPublish Date: %s" % vid.youtube_pubdate)
        print("\t\tPublic: %s" % vid.youtube_public)
        if vid.comment:
            print("\tComment:\n\t\t%s" % vid.comment.replace("\n", "\n\t\t"))
Example #3
class SbciFinanceDB(object):
    '''TODO'''

    def __init__(self, verbose=0, *args, **kwds):  # @UnusedVariable
        super(SbciFinanceDB, self).__init__(*args, **kwds)

        if not os.access(FINANCEDB_FILE, os.R_OK | os.W_OK):
            raise RuntimeError('cannot access Finance DB file ({}) for R/W!'
                               .format(FINANCEDB_FILE))

        self.Base = automap_base()

        self.engine = create_engine('sqlite:///' + FINANCEDB_FILE)

        self.Base.prepare(self.engine, reflect=True)

        self.Categories = self.Base.classes.categories
        self.Seasons = self.Base.classes.seasons
        self.Cheques = self.Base.classes.cheques
        self.Transactions = self.Base.classes.transactions
        self.Trybooking = self.Base.classes.trybooking

        self.dbsession = Session(self.engine)

        self.categories_query = self.dbsession.query(self.Categories)
        self.seasons_query = self.dbsession.query(self.Seasons)
        self.cheques_query = self.dbsession.query(self.Cheques)
        self.transactions_query = self.dbsession.query(self.Transactions)
        self.trybooking_query = self.dbsession.query(self.Trybooking)
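
A minimal usage sketch for SbciFinanceDB, assuming FINANCEDB_FILE exists and the reflected categories table exposes a name column (both assumptions):

# Open the automapped finance database and list category names.
finance = SbciFinanceDB()
for category in finance.categories_query.all():
    print(category.name)
finance.dbsession.close()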
Example #4
 def __init__(self, db, autocommit=False, autoflush=False, **options):
     self.app = db.get_app()
     self._model_changes = {}
     Session.__init__(self, autocommit=autocommit, autoflush=autoflush,
                      extension=db.session_extensions,
                      bind=db.engine,
                      binds=db.get_binds(self.app), **options)
Example #5
def test_ui_cancel_withdraw(
    logged_in_wallet_user_browser: DriverAPI, dbsession: Session, user_phone_number, top_up_user, eth_asset_id
):
    """Create new account through UI."""

    # Record balance before cancel
    with transaction.manager:
        uca = dbsession.query(UserCryptoAddress).first()
        asset = dbsession.query(Asset).get(eth_asset_id)
        original_balance = uca.get_crypto_account(asset).account.get_balance()

    # Go to address
    b = logged_in_wallet_user_browser
    b.find_by_css("#nav-wallet").click()
    b.find_by_css("#row-asset-{} a.withdraw-asset".format(eth_asset_id)).click()

    b.fill("address", TEST_ADDRESS)
    b.fill("amount", "0.1")
    b.find_by_css("button[name='process']").click()

    # We should arrive at the confirmation page
    assert b.is_element_present_by_css("#heading-confirm-withdraw")

    b.find_by_css("button[name='cancel']").click()

    # We get a text box telling us the withdraw was cancelled
    assert b.is_element_present_by_css("#msg-withdraw-cancelled")

    # Balance should be back to original
    with transaction.manager:
        uca = dbsession.query(UserCryptoAddress).first()
        asset = dbsession.query(Asset).get(eth_asset_id)
        assert original_balance == uca.get_crypto_account(asset).account.get_balance()
Example #6
def init_settings(session: Session):
    set_new_setting(session, "ffmpeg_path", "ffmpeg")
    set_new_setting(session, "ffmpeg_verbose", False)
    set_new_setting(session, "input_directory", "/set/me")
    set_new_setting(session, "output_directory", "esa_videos")
    set_new_setting(session, "final_output_directory", "esa_final_videos")
    set_new_setting(session, "common_keywords", "ESA")
    set_new_setting(session, "input_video_extension", ".flv")
    set_new_setting(session, "output_video_extension", ".mp4")
    set_new_setting(session, "output_video_prefix", "esa_")
    set_new_setting(session, "upload_logfile", "esa_upload.log")
    set_new_setting(session, "youtube_public", False)
    set_new_setting(session, "youtube_schedule", False)
    set_new_setting(session, "youtube_schedule_start_ts", 0)
    set_new_setting(session, "schedule_offset_hours", 12)
    set_new_setting(session, "upload_to_twitch", True)
    set_new_setting(session, "upload_to_youtube", True)
    set_new_setting(session, "delete_after_upload", True)
    set_new_setting(session, "upload_ratelimit", 2.5)
    set_new_setting(session, "upload_retry_count", 3)
    set_new_setting(session, "twitch_ul_pid", 0)
    set_new_setting(session, "twitch_ul_create_time", 0.0)
    set_new_setting(session, "youtube_ul_pid", 0)
    set_new_setting(session, "youtube_ul_create_time", 0.0)
    set_new_setting(session, "proc_ul_pid", 0)
    set_new_setting(session, "proc_ul_create_time", 0.0)
    set_new_setting(session, "init_proc_status", "")
    session.commit()
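
set_new_setting() is not shown in this example; a hedged guess at its contract is that it inserts a row only when the key is absent, so re-running init_settings() never overwrites edited values (the Setting model and its columns are hypothetical):

def set_new_setting(session: Session, name, value):
    # Only add the setting if it does not exist yet.
    existing = session.query(Setting).filter(Setting.name == name).first()
    if existing is None:
        session.add(Setting(name=name, value=str(value)))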
Example #7
def post_process_video(session: Session, debug):
    vids = session.query(Video)
    vids = vids.filter(Video.processing_done)
    vids = vids.filter(not_(Video.post_processing_done))
    vids = vids.filter(or_(and_(Video.do_twitch, Video.done_twitch), not_(Video.do_twitch)))
    vids = vids.filter(or_(and_(Video.do_youtube, Video.done_youtube), not_(Video.do_youtube)))
    vids = vids.order_by(Video.id.asc())
    vid = vids.first()

    if not vid:
        print("No video in need of processing found")
        return 0

    out_dir = get_setting(session, "output_directory")
    final_dir = get_setting(session, "final_output_directory")
    out_prefix = get_setting(session, "output_video_prefix")
    out_ext = get_setting(session, "output_video_extension")

    out_fname = "%s%s%s" % (out_prefix, vid.id, out_ext)
    out_path = os.path.join(out_dir, out_fname)
    final_path = os.path.join(final_dir, out_fname)

    if out_path != final_path:
        shutil.move(out_path, final_path)
        vid.post_processing_status = "Moved %s to %s" % (out_path, final_path)
    else:
        vid.post_processing_status = "Nothing to do"

    vid.post_processing_done = True
    session.commit()

    return 0
Example #8
class TrialModelTestCase(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        global transaction, connection, engine

        # Connect to the database and create the schema within a transaction
        engine = create_engine(TEST_DATABASE_URI)
        connection = engine.connect()
        transaction = connection.begin()
        Trial.metadata.create_all(connection)

        # Load test trials fixtures from xml files
        nct_ids = ['NCT02034110', 'NCT00001160', 'NCT00001163']
        cls.trials = load_sample_trials(nct_ids)

    @classmethod
    def tearDownClass(cls):
        # Roll back the top level transaction and disconnect from the database
        transaction.rollback()
        connection.close()
        engine.dispose()

    def setUp(self):
        self.__transaction = connection.begin_nested()
        self.session = Session(connection)

    def tearDown(self):
        self.session.close()
        self.__transaction.rollback()

    def test_add(self):
        trial = Trial(ct_dict=self.trials[0])
        self.session.add(trial)
Example #9
 def __init__(self, db, autocommit=False, autoflush=False, **options):
     self.sender = db.sender
     self._model_changes = {}
     Session.__init__(self, autocommit=autocommit, autoflush=autoflush,
                      expire_on_commit=False,
                      extension=db.session_extensions,
                      bind=db.engine, **options)
Example #10
    def get_document_rows(self, keywords, domains=set([]), dmset = set([])):
        # Create a new session
        session = Session(bind = engine)

        # Look up the article keywords
        kres = KeywordIDResolutionService()
        _keywords = {k: kres.resolve(k) for k in keywords}
        resolved = 0
        for k in _keywords:
            if _keywords[k] is None:
                yield QueryMessage("No matching keyword: %s", k)
            else:
                resolved += 1
        if resolved == 0:
            raise QueryException("No matching keywords.")

        # Find the sites which talk about a particular keyword 
        sql = """ SELECT domains.`key`, COUNT(*) AS c from domains JOIN articles ON articles.domain_id = domains.id 
            JOIN documents ON documents.article_id = articles.id 
            JOIN keyword_adjacencies ON keyword_adjacencies.doc_id = documents.id 
            WHERE keyword_adjacencies.key1_id IN (:keys)
            OR keyword_adjacencies.key2_id IN (:keys)
            GROUP BY domains.id 
            ORDER BY c DESC 
            LIMIT 0,5
        """
        for key, count in session.execute(sql, ({'keys': ','.join([str(i) for i in _keywords.values()])})):
            logging.info((key, count))
            domains.add(key)

        return self._kd_proc.get_document_rows(keywords, domains, dmset)
Example #11
def cache_keywords():
    core.configure_logging('debug')
    from backend.db import Keyword
    engine = core.get_database_engine_string()
    logging.info("Using connection string '%s'" % (engine,))
    engine = create_engine(engine, encoding='utf-8', isolation_level="READ UNCOMMITTED")
    session = Session(bind=engine, autocommit = False)

    # Estimate the number of keywords
    logging.debug("Estimating number of keywords...")
    for count, in session.execute("SELECT COUNT(*) FROM keywords"):
        total = count 

    logging.debug("Establishing connection to redis...")
    r = get_redis_instance(1)

    logging.info("Caching %d keywords...", total)
    cached = 0
    for _id, word in session.execute("SELECT id, word FROM keywords"):
        assert r.set(word, _id)
        cached += 1
        if cached % 1000 == 0:
            logging.info("Cached %d keywords (%.2f%% done)", cached, 100.0*cached/total)

    logging.info("Cached %d keywords (%.2f%% done)", cached, 100.0*cached/total)
Example #12
 def __init__(self, db, autocommit=False, autoflush=True, **options):
     #: The application that this session belongs to.
     self.app = db.get_app()
     bind = options.pop('bind', None) or db.engine
     SessionBase.__init__(self, autocommit=autocommit, autoflush=autoflush,
                          bind=bind,
                          binds=db.get_binds(self.app), **options)
Example #13
    def test_get_object(self):
        """
            Test the method _get_object() using a sync key
            Test scenario:
            Get the object with sync_key
        """

        obj_type = self.powerVCMapping.obj_type
        sync_key = self.powerVCMapping.sync_key

        self.aMox.StubOutWithMock(session, 'query')
        session.query(model.PowerVCMapping).AndReturn(query)

        self.aMox.StubOutWithMock(query, 'filter_by')
        query.filter_by(
            obj_type=obj_type, sync_key=sync_key).AndReturn(query)

        self.aMox.StubOutWithMock(query, 'one')
        query.one().AndReturn(self.powerVCMapping)

        self.aMox.ReplayAll()
        returnValue = self.powervcagentdb._get_object(
            obj_type=obj_type, sync_key=sync_key)
        self.aMox.VerifyAll()
        self.assertEqual(returnValue, self.powerVCMapping)
        self.aMox.UnsetStubs()
Example #14
    def test_set_object_local_id(self):
        """
            Test the method _set_object_local_id(self, obj, local_id)
            Test scenario:
            Set the local_id of the specified object when the pvc_id is none
        """

        obj_id = self.powerVCMapping.id
        self.powerVCMapping.pvc_id = None
        self.powerVCMapping.local_id = None
        self.powerVCMapping.status = None

        self.aMox.StubOutWithMock(session, 'query')
        session.query(model.PowerVCMapping).AndReturn(query)

        self.aMox.StubOutWithMock(query, 'filter_by')
        query.filter_by(id=obj_id).AndReturn(query)

        self.aMox.StubOutWithMock(query, 'one')
        query.one().AndReturn(self.powerVCMapping)

        self.aMox.StubOutWithMock(session, 'merge')
        session.merge(self.powerVCMapping).AndReturn("")

        self.aMox.ReplayAll()
        self.powervcagentdb._set_object_local_id(self.powerVCMapping, 'test')
        self.aMox.VerifyAll()
        self.assertEqual(self.powerVCMapping.status, 'Creating')
        self.assertEqual(self.powerVCMapping.local_id, 'test')
        self.aMox.UnsetStubs()
Example #15
    def test_delete_existing_object(self):
        """
            Test the method _delete_object(self, obj) when the object exists
            Test scenario:
            When the data is in the database, the delete operation should
            complete successfully
        """

        self.aMox.StubOutWithMock(session, 'query')
        session.query(model.PowerVCMapping).AndReturn(query)

        self.aMox.StubOutWithMock(query, 'filter_by')
        query.filter_by(id=self.powerVCMapping['id']).AndReturn(query)

        self.aMox.StubOutWithMock(query, 'one')
        query.one().AndReturn(self.powerVCMapping)

        self.aMox.StubOutWithMock(session, 'begin')
        session.begin(subtransactions=True).AndReturn(transaction(None, None))

        self.aMox.StubOutWithMock(session, 'delete')
        returnValue = session.delete(self.powerVCMapping).AndReturn(True)

        self.aMox.ReplayAll()

        self.powervcagentdb._delete_object(self.powerVCMapping)

        self.aMox.VerifyAll()

        self.assertEqual(returnValue, True)

        self.aMox.UnsetStubs()
Example #16
class DBBackedController(object):

    def __init__(self, engine, session=None):

        if type(engine) == types.StringType:
            logging.info("Using connection string '%s'" % (engine,))
            new_engine = create_engine(engine, encoding='utf-8')
            if "sqlite:" in engine:
                logging.debug("Setting text factory for unicode compat.")
                new_engine.raw_connection().connection.text_factory = str 
            self._engine = new_engine

        else:
            logging.info("Using existing engine...")
            self._engine = engine

        if session is None:
            logging.info("Binding session...")
            self._session = Session(bind=self._engine)
        else:
            self._session = session

        logging.info("Updating metadata...")
        Base.metadata.create_all(self._engine)

    def commit(self):
        logging.info("Commiting...")
        self._session.commit()
Example #17
class ServiceTest(unittest.TestCase):
    def setup(self):
        self.__transaction = connection.begin_nested()
        self.session = Session(connection)

    def teardown(self):
        self.session.close()
        self.__transaction.rollback()
Example #18
 def update_translation_table():
     from camelot.model.i18n import Translation
     from sqlalchemy.orm.session import Session
     t = Translation.get_by(source=source, language=language)
     if not t:
         t = Translation(source=source, language=language)
     t.value = value
     Session.object_session( t ).flush( [t] )
Example #19
def session(engine):
    session = Session(engine)

    yield session

    session.close()

    exec_sql("SELECT truncate_tables('osm_test')")
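
The fixture above presupposes a pytest engine fixture and a truncate_tables() SQL function that are not shown; one plausible companion engine fixture, with the connection URI assumed:

import pytest
from sqlalchemy import create_engine

@pytest.fixture(scope="session")
def engine():
    # Engine for the osm_test database consumed by the session fixture above.
    engine = create_engine("postgresql:///osm_test")
    yield engine
    engine.dispose()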
Example #20
 def __init__(self, db, autocommit=False, autoflush=False, **options):
     #import pdb; pdb.set_trace() # this is temporary!
     self.app = db.get_app()
     self._model_changes = {}
     Session.__init__(self, autocommit=autocommit, autoflush=autoflush,
                      extension=db.session_extensions,
                      bind=db.engine,
                      binds=db.get_binds(self.app), **options)
Example #21
File: query.py Project: badock/rome
 def delete(self, synchronize_session='evaluate'):
     from rome.core.session.session import Session
     temporary_session = Session()
     objects = self.matching_objects(filter_deleted=False)
     for obj in objects:
         temporary_session.delete(obj)
     temporary_session.flush()
     return len(objects)
Example #22
class DatabaseTest(object):
    def setup(self):
        self.__transaction = connection.begin_nested()
        self.session = Session(connection)

    def teardown(self):
        self.session.close()
        self.__transaction.rollback()
Example #23
class DatabaseTest(unittest.TestCase):

    def setUp(self):
        self.trans = connection.begin()
        self.session = Session(connection)

    def tearDown(self):
        self.trans.rollback()
        self.session.close()
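
The test bases in Examples #22, #23 and #27 all rely on a module-level connection; a minimal sketch of the setup they presuppose (engine URI and names assumed):

from sqlalchemy import create_engine

# One connection shared by all tests; each test opens a transaction or
# SAVEPOINT on it and rolls the work back afterwards.
engine = create_engine("sqlite:///:memory:")
connection = engine.connect()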
Example #24
def create_initial_kanban_acl(mapper, connection, target):
    acl_rules = [
                 ('role:redturtle_developer', 'view'),
                ]

    for principal_id, permission_name in acl_rules:
        acl = KanbanACL(principal=principal_id,
                board_id=target.id,
                permission_name=permission_name)
        Session.object_session(target).add(acl)
Example #25
    def __init__(self, db, autocommit=False, autoflush=True, **options):
        #: The application that this session belongs to.
        self.app = app = db.get_app()
        track_modifications = app.config["SQLALCHEMY_TRACK_MODIFICATIONS"]
        bind = options.pop("bind", None) or db.engine
        binds = options.pop("binds", db.get_binds(app))

        if track_modifications is None or track_modifications:
            _SessionSignalEvents.register(self)

        SessionBase.__init__(self, autocommit=autocommit, autoflush=autoflush, bind=bind, binds=binds, **options)
Example #26
            def request():
                from sqlalchemy.orm.session import Session
                from camelot.view.remote_signals import get_signal_handler

                o = entity_getter()
                self._model_function(o)
                if self._flush:
                    sh = get_signal_handler()
                    Session.object_session(o).flush([o])
                    sh.sendEntityUpdate(self, o)
                return True
Example #27
class DatabaseTest(TestCase):
    """
    Base class for our tests that involve the database.
    """
    def setUp(self):
        self.__transaction = connection.begin_nested()
        self.session = Session(connection)

    def tearDown(self):
        self.session.close()
        self.__transaction.rollback()
Example #28
 def run(self):
     from sqlalchemy.orm.session import Session
     from camelot.view.remote_signals import get_signal_handler
     signal_handler = get_signal_handler()
     collection = list(self._collection_getter())
     self.update_maximum_signal.emit( len(collection) )
     for i, entity in enumerate(collection):
         message = self.update_entity(entity)
         Session.object_session( entity ).flush( [entity] )
         signal_handler.sendEntityUpdate( self, entity )
         self.update_progress_signal.emit( i, message or '')
Example #29
    def add_clarification_question(self, question, answer):
        clarification_question = BriefClarificationQuestion(
            brief=self,
            question=question,
            answer=answer,
        )
        clarification_question.validate()

        Session.object_session(self).add(clarification_question)

        return clarification_question
Example #30
    def job_status(self, job_id):
        """Return information about a job request"""
        s = Session()
        r = s.query(JobRequest).get(job_id)

        if not r:
            return self._failed("Job %s not found" % job_id, 404)

        retval = r.asDict()

        return self._ok(retval)
Example #31
 def get(session: Session, club_id: int):
     return session.query(Club).filter(Club.id == club_id).first()
Example #32
 def getAll(session: Session) -> list:
     return session.query(Club).all()
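
Examples #31, #32 and #51 query a Club model that is not defined here; a hypothetical minimal mapping, purely for context:

from sqlalchemy import Column, Integer, String
from sqlalchemy.orm import declarative_base

Base = declarative_base()

class Club(Base):
    __tablename__ = 'clubs'
    id = Column(Integer, primary_key=True)
    name = Column(String)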
Example #33
    def update_state(
        self,
        session: Session = None,
        execute_callbacks: bool = True
    ) -> Tuple[List[TI], Optional[callback_requests.DagCallbackRequest]]:
        """
        Determines the overall state of the DagRun based on the state
        of its TaskInstances.

        :param session: Sqlalchemy ORM Session
        :type session: Session
        :param execute_callbacks: Should dag callbacks (success/failure, SLA etc) be invoked
            directly (default: true) or recorded as a pending request in the ``callback`` property
        :type execute_callbacks: bool
        :return: Tuple containing tis that can be scheduled in the current loop & `callback` that
            needs to be executed
        """
        # Callback to execute in case of Task Failures
        callback: Optional[callback_requests.DagCallbackRequest] = None

        start_dttm = timezone.utcnow()
        self.last_scheduling_decision = start_dttm

        dag = self.get_dag()
        ready_tis: List[TI] = []
        tis = list(self.get_task_instances(session=session, state=State.task_states + (State.SHUTDOWN,)))
        self.log.debug("number of tis tasks for %s: %s task(s)", self, len(tis))
        for ti in tis:
            ti.task = dag.get_task(ti.task_id)

        unfinished_tasks = [t for t in tis if t.state in State.unfinished]
        finished_tasks = [t for t in tis if t.state in State.finished | {State.UPSTREAM_FAILED}]
        none_depends_on_past = all(not t.task.depends_on_past for t in unfinished_tasks)
        none_task_concurrency = all(t.task.task_concurrency is None for t in unfinished_tasks)
        if unfinished_tasks:
            scheduleable_tasks = [ut for ut in unfinished_tasks if ut.state in SCHEDULEABLE_STATES]
            self.log.debug(
                "number of scheduleable tasks for %s: %s task(s)",
                self, len(scheduleable_tasks))
            ready_tis, changed_tis = self._get_ready_tis(scheduleable_tasks, finished_tasks, session)
            self.log.debug("ready tis length for %s: %s task(s)", self, len(ready_tis))
            if none_depends_on_past and none_task_concurrency:
                # small speed up
                are_runnable_tasks = ready_tis or self._are_premature_tis(
                    unfinished_tasks, finished_tasks, session) or changed_tis

        duration = (timezone.utcnow() - start_dttm)
        Stats.timing("dagrun.dependency-check.{}".format(self.dag_id), duration)

        leaf_task_ids = {t.task_id for t in dag.leaves}
        leaf_tis = [ti for ti in tis if ti.task_id in leaf_task_ids]

        # if all roots finished and at least one failed, the run failed
        if not unfinished_tasks and any(
            leaf_ti.state in {State.FAILED, State.UPSTREAM_FAILED} for leaf_ti in leaf_tis
        ):
            self.log.error('Marking run %s failed', self)
            self.set_state(State.FAILED)
            if execute_callbacks:
                dag.handle_callback(self, success=False, reason='task_failure', session=session)
            else:
                callback = callback_requests.DagCallbackRequest(
                    full_filepath=dag.fileloc,
                    dag_id=self.dag_id,
                    execution_date=self.execution_date,
                    is_failure_callback=True,
                    msg='task_failure'
                )

        # if all leafs succeeded and no unfinished tasks, the run succeeded
        elif not unfinished_tasks and all(
            leaf_ti.state in {State.SUCCESS, State.SKIPPED} for leaf_ti in leaf_tis
        ):
            self.log.info('Marking run %s successful', self)
            self.set_state(State.SUCCESS)
            if execute_callbacks:
                dag.handle_callback(self, success=True, reason='success', session=session)
            else:
                callback = callback_requests.DagCallbackRequest(
                    full_filepath=dag.fileloc,
                    dag_id=self.dag_id,
                    execution_date=self.execution_date,
                    is_failure_callback=False,
                    msg='success'
                )

        # if *all tasks* are deadlocked, the run failed
        elif (unfinished_tasks and none_depends_on_past and
              none_task_concurrency and not are_runnable_tasks):
            self.log.error('Deadlock; marking run %s failed', self)
            self.set_state(State.FAILED)
            if execute_callbacks:
                dag.handle_callback(self, success=False, reason='all_tasks_deadlocked', session=session)
            else:
                callback = callback_requests.DagCallbackRequest(
                    full_filepath=dag.fileloc,
                    dag_id=self.dag_id,
                    execution_date=self.execution_date,
                    is_failure_callback=True,
                    msg='all_tasks_deadlocked'
                )

        # finally, if the roots aren't done, the dag is still running
        else:
            self.set_state(State.RUNNING)

        self._emit_duration_stats_for_finished_state()

        session.merge(self)

        return ready_tis, callback
Example #34
def test_get_user_by_phonenumber(user: User, session: Session,
                                 user_service: UserService):
    assert session.query(User).count() == 1
    xuser = user_service.get_by_phone_number(session, user.phone_number)
    assert xuser is not None
    assert user == xuser
Example #35
def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_constraint('datasets_user_id_users_fkey',
                       'datasets',
                       schema='sampledb',
                       type_='foreignkey')
    session = Session(bind=op.get_bind())
    session.execute(f"DROP VIEW IF EXISTS {DatasetView.__table__}")
    session.commit()

    result = session.execute(
        """SELECT con.conname, information_schema.tables.table_name
                   FROM pg_catalog.pg_constraint con
                        INNER JOIN pg_catalog.pg_class rel
                                   ON rel.oid = con.conrelid
                        INNER JOIN pg_catalog.pg_namespace nsp
                                   ON nsp.oid = connamespace
                        INNER JOIN information_schema.tables
                                    ON  information_schema.tables.table_schema = nsp.nspname
                                    AND  nsp.nspname = 'sampledb'
                                    AND rel.relname = information_schema.tables.table_name
                   WHERE information_schema.tables.table_name ~ '^dataset_'
                   AND con.conname ~ 'user_id_fkey$';
                   """)

    for r in result:
        op.drop_constraint(f'{r[0]}',
                           f'{r[1]}',
                           schema='sampledb',
                           type_='foreignkey')
        session.execute(f"DROP VIEW IF EXISTS sampledb.v_{r[1]};")
    session.execute(f"DROP TABLE sampledb.users ;")
    session.commit()

    session.commit()
Example #36
    def update_state(
        self,
        session: Session = NEW_SESSION,
        execute_callbacks: bool = True
    ) -> Tuple[List[TI], Optional[DagCallbackRequest]]:
        """
        Determines the overall state of the DagRun based on the state
        of its TaskInstances.

        :param session: Sqlalchemy ORM Session
        :param execute_callbacks: Should dag callbacks (success/failure, SLA etc) be invoked
            directly (default: true) or recorded as a pending request in the ``callback`` property
        :return: Tuple containing tis that can be scheduled in the current loop & `callback` that
            needs to be executed
        """
        # Callback to execute in case of Task Failures
        callback: Optional[DagCallbackRequest] = None

        start_dttm = timezone.utcnow()
        self.last_scheduling_decision = start_dttm
        with Stats.timer(f"dagrun.dependency-check.{self.dag_id}"):
            dag = self.get_dag()
            info = self.task_instance_scheduling_decisions(session)

            tis = info.tis
            schedulable_tis = info.schedulable_tis
            changed_tis = info.changed_tis
            finished_tis = info.finished_tis
            unfinished_tis = info.unfinished_tis

            none_depends_on_past = all(not t.task.depends_on_past
                                       for t in unfinished_tis)
            none_task_concurrency = all(t.task.max_active_tis_per_dag is None
                                        for t in unfinished_tis)
            none_deferred = all(t.state != State.DEFERRED
                                for t in unfinished_tis)

            if unfinished_tis and none_depends_on_past and none_task_concurrency and none_deferred:
                # small speed up
                are_runnable_tasks = (schedulable_tis
                                      or self._are_premature_tis(
                                          unfinished_tis, finished_tis,
                                          session) or changed_tis)

        leaf_task_ids = {t.task_id for t in dag.leaves}
        leaf_tis = [ti for ti in tis if ti.task_id in leaf_task_ids]

        # if all roots finished and at least one failed, the run failed
        if not unfinished_tis and any(leaf_ti.state in State.failed_states
                                      for leaf_ti in leaf_tis):
            self.log.error('Marking run %s failed', self)
            self.set_state(DagRunState.FAILED)
            if execute_callbacks:
                dag.handle_callback(self,
                                    success=False,
                                    reason='task_failure',
                                    session=session)
            elif dag.has_on_failure_callback:
                callback = DagCallbackRequest(
                    full_filepath=dag.fileloc,
                    dag_id=self.dag_id,
                    run_id=self.run_id,
                    is_failure_callback=True,
                    msg='task_failure',
                )

        # if all leaves succeeded and no unfinished tasks, the run succeeded
        elif not unfinished_tis and all(leaf_ti.state in State.success_states
                                        for leaf_ti in leaf_tis):
            self.log.info('Marking run %s successful', self)
            self.set_state(DagRunState.SUCCESS)
            if execute_callbacks:
                dag.handle_callback(self,
                                    success=True,
                                    reason='success',
                                    session=session)
            elif dag.has_on_success_callback:
                callback = DagCallbackRequest(
                    full_filepath=dag.fileloc,
                    dag_id=self.dag_id,
                    run_id=self.run_id,
                    is_failure_callback=False,
                    msg='success',
                )

        # if *all tasks* are deadlocked, the run failed
        elif (unfinished_tis and none_depends_on_past and none_task_concurrency
              and none_deferred and not are_runnable_tasks):
            self.log.error('Deadlock; marking run %s failed', self)
            self.set_state(DagRunState.FAILED)
            if execute_callbacks:
                dag.handle_callback(self,
                                    success=False,
                                    reason='all_tasks_deadlocked',
                                    session=session)
            elif dag.has_on_failure_callback:
                callback = DagCallbackRequest(
                    full_filepath=dag.fileloc,
                    dag_id=self.dag_id,
                    run_id=self.run_id,
                    is_failure_callback=True,
                    msg='all_tasks_deadlocked',
                )

        # finally, if the roots aren't done, the dag is still running
        else:
            self.set_state(DagRunState.RUNNING)

        if self._state == DagRunState.FAILED or self._state == DagRunState.SUCCESS:
            msg = ("DagRun Finished: dag_id=%s, execution_date=%s, run_id=%s, "
                   "run_start_date=%s, run_end_date=%s, run_duration=%s, "
                   "state=%s, external_trigger=%s, run_type=%s, "
                   "data_interval_start=%s, data_interval_end=%s, dag_hash=%s")
            self.log.info(
                msg,
                self.dag_id,
                self.execution_date,
                self.run_id,
                self.start_date,
                self.end_date,
                (self.end_date - self.start_date).total_seconds()
                if self.start_date and self.end_date else None,
                self._state,
                self.external_trigger,
                self.run_type,
                self.data_interval_start,
                self.data_interval_end,
                self.dag_hash,
            )
            session.flush()

        self._emit_true_scheduling_delay_stats_for_finished_state(finished_tis)
        self._emit_duration_stats_for_finished_state()

        session.merge(self)
        # We do not flush here for performance reasons(It increases queries count by +20)

        return schedulable_tis, callback
Example #37
    def addHosts(self, session: Session, addHostRequest: dict) -> None:
        """
        Raises:
            HardwareProfileNotFound
            ResourceAdapterNotFound
        """

        self._logger.debug('addHosts()')

        dbHardwareProfile = \
            HardwareProfilesDbHandler().getHardwareProfile(
                session, addHostRequest['hardwareProfile'])

        if not dbHardwareProfile.resourceadapter:
            errmsg = ('Resource adapter not defined for hardware'
                      ' profile [%s]' % (dbHardwareProfile.name))

            self._logger.error(errmsg)

            raise ResourceAdapterNotFound(errmsg)

        softwareProfileName = addHostRequest['softwareProfile'] \
            if 'softwareProfile' in addHostRequest else None

        dbSoftwareProfile = \
            SoftwareProfilesDbHandler().getSoftwareProfile(
                session, softwareProfileName) \
            if softwareProfileName else None

        ResourceAdapterClass = \
            resourceAdapterFactory.get_resourceadapter_class(
                dbHardwareProfile.resourceadapter.name)

        resourceAdapter = ResourceAdapterClass(
            addHostSession=addHostRequest['addHostSession'])

        resourceAdapter.session = session

        # Call the start() method of the resource adapter
        newNodes = resourceAdapter.start(addHostRequest,
                                         session,
                                         dbHardwareProfile,
                                         dbSoftwareProfile=dbSoftwareProfile)

        session.add_all(newNodes)
        session.flush()

        if 'tags' in addHostRequest and addHostRequest['tags']:
            for node in newNodes:
                self._set_tags(node, addHostRequest['tags'])

        # Commit new node(s) to database
        session.commit()

        # Only perform post-add operations if we actually added a node
        if newNodes:
            self._logger.info(
                'Node(s) added to software profile [%s] and'
                ' hardware profile [%s]',
                dbSoftwareProfile.name if dbSoftwareProfile else 'None',
                dbHardwareProfile.name,
            )

            newNodeNames = [tmpNode.name for tmpNode in newNodes]

            resourceAdapter.hookAction('add', newNodeNames)

            self.postAddHost(session, dbHardwareProfile.name,
                             softwareProfileName,
                             addHostRequest['addHostSession'])

            resourceAdapter.hookAction('start', newNodeNames)

        self._logger.debug('Add host workflow complete')
Example #38
 def delete(cls, id_):
     obj = cls.query().get(id_)
     Session.delete(obj)
     Session.commit()
Example #39
 def getAll(session: Session):
     return session.query(Site_category)
Example #40
 def session(self):
     return Session.object_session(self)
Example #41
    def expand_mapped_task(self, run_id: str, *, session: Session) -> Tuple[Sequence["TaskInstance"], int]:
        """Create the mapped task instances for mapped task.

        :return: The newly created mapped TaskInstances (if any) in ascending order by map index, and the
            maximum map_index.
        """
        from airflow.models.taskinstance import TaskInstance
        from airflow.settings import task_instance_mutation_hook

        total_length = functools.reduce(
            operator.mul, self._resolve_map_lengths(run_id, session=session).values()
        )

        state: Optional[TaskInstanceState] = None
        unmapped_ti: Optional[TaskInstance] = (
            session.query(TaskInstance)
            .filter(
                TaskInstance.dag_id == self.dag_id,
                TaskInstance.task_id == self.task_id,
                TaskInstance.run_id == run_id,
                TaskInstance.map_index == -1,
                or_(TaskInstance.state.in_(State.unfinished), TaskInstance.state.is_(None)),
            )
            .one_or_none()
        )

        all_expanded_tis: List[TaskInstance] = []

        if unmapped_ti:
            # The unmapped task instance still exists and is unfinished, i.e. we
            # haven't tried to run it before.
            if total_length < 1:
                # If the upstream maps this to a zero-length value, simply mark the
                # unmapped task instance as SKIPPED (if needed).
                self.log.info(
                    "Marking %s as SKIPPED since the map has %d values to expand",
                    unmapped_ti,
                    total_length,
                )
                unmapped_ti.state = TaskInstanceState.SKIPPED
            else:
                # Otherwise convert this into the first mapped index, and create
                # TaskInstance for other indexes.
                unmapped_ti.map_index = 0
                self.log.debug("Updated in place to become %s", unmapped_ti)
                all_expanded_tis.append(unmapped_ti)
            state = unmapped_ti.state
            indexes_to_map = range(1, total_length)
        else:
            # Only create "missing" ones.
            current_max_mapping = (
                session.query(func.max(TaskInstance.map_index))
                .filter(
                    TaskInstance.dag_id == self.dag_id,
                    TaskInstance.task_id == self.task_id,
                    TaskInstance.run_id == run_id,
                )
                .scalar()
            )
            indexes_to_map = range(current_max_mapping + 1, total_length)

        for index in indexes_to_map:
            # TODO: Make more efficient with bulk_insert_mappings/bulk_save_mappings.
            ti = TaskInstance(self, run_id=run_id, map_index=index, state=state)
            self.log.debug("Expanding TIs upserted %s", ti)
            task_instance_mutation_hook(ti)
            ti = session.merge(ti)
            ti.refresh_from_task(self)  # session.merge() loses task information.
            all_expanded_tis.append(ti)

        # Set to "REMOVED" any (old) TaskInstances with map indices greater
        # than the current map value
        session.query(TaskInstance).filter(
            TaskInstance.dag_id == self.dag_id,
            TaskInstance.task_id == self.task_id,
            TaskInstance.run_id == run_id,
            TaskInstance.map_index >= total_length,
        ).update({TaskInstance.state: TaskInstanceState.REMOVED})

        session.flush()
        return all_expanded_tis, total_length
Example #42
 def for_library(cls, key, library):
     """Find or create a ConfigurationSetting for the given Library."""
     _db = Session.object_session(library)
     return cls.for_library_and_externalintegration(_db, key, library, None)
Example #43
    def _find_scheduled_tasks(
            self,
            dag_run: DagRun,
            session: Session,
            check_execution_date=False) -> Optional[List[TI]]:
        """
        Make scheduling decisions about an individual dag run

        ``currently_active_runs`` is passed in so that a batch query can be
        used to ask this for all dag runs in the batch, to avoid an n+1 query.

        :param dag_run: The DagRun to schedule
        :return: scheduled tasks
        """
        if not dag_run or dag_run.get_state() in State.finished:
            return
        try:
            dag = dag_run.dag = self.dagbag.get_dag(dag_run.dag_id,
                                                    session=session)
        except SerializedDagNotFound:
            self.log.exception("DAG '%s' not found in serialized_dag table",
                               dag_run.dag_id)
            return None

        if not dag:
            self.log.error("Couldn't find dag %s in DagBag/DB!",
                           dag_run.dag_id)
            return None

        currently_active_runs = session.query(TI.execution_date, ).filter(
            TI.dag_id == dag_run.dag_id,
            TI.state.notin_(list(State.finished)),
        ).all()

        if check_execution_date and dag_run.execution_date > timezone.utcnow(
        ) and not dag.allow_future_exec_dates:
            self.log.warning("Execution date is in future: %s",
                             dag_run.execution_date)
            return None

        if dag.max_active_runs:
            if (len(currently_active_runs) >= dag.max_active_runs
                    and dag_run.execution_date not in currently_active_runs):
                self.log.info(
                    "DAG %s already has %d active runs, not queuing any tasks for run %s",
                    dag.dag_id,
                    len(currently_active_runs),
                    dag_run.execution_date,
                )
                return None

        self._verify_integrity_if_dag_changed(dag_run=dag_run, session=session)

        schedulable_tis, callback_to_run = dag_run.update_state(
            session=session, execute_callbacks=False)
        dag_run.schedule_tis(schedulable_tis, session)
        session.commit()

        query = (session.query(TI).outerjoin(TI.dag_run).filter(
            or_(DR.run_id.is_(None),
                DR.run_type != DagRunType.BACKFILL_JOB)).join(
                    TI.dag_model).filter(not_(DM.is_paused)).filter(
                        TI.state == State.SCHEDULED).options(
                            selectinload('dag_model')))
        scheduled_tis: List[TI] = with_row_locks(
            query,
            of=TI,
            **skip_locked(session=session),
        ).all()
        return scheduled_tis
Example #44
    def verify_integrity(self, session: Session = NEW_SESSION):
        """
        Verifies the DagRun by checking for removed tasks or tasks that are not in the
        database yet. It will set state to removed or add the task if required.

        :param session: Sqlalchemy ORM Session
        """
        from airflow.settings import task_instance_mutation_hook

        dag = self.get_dag()
        tis = self.get_task_instances(session=session)

        # check for removed or restored tasks
        task_ids = set()
        for ti in tis:
            task_instance_mutation_hook(ti)
            task_ids.add(ti.task_id)
            task = None
            try:
                task = dag.get_task(ti.task_id)
            except AirflowException:
                if ti.state == State.REMOVED:
                    pass  # ti has already been removed, just ignore it
                elif self.state != State.RUNNING and not dag.partial:
                    self.log.warning(
                        "Failed to get task '%s' for dag '%s'. Marking it as removed.",
                        ti, dag)
                    Stats.incr(f"task_removed_from_dag.{dag.dag_id}", 1, 1)
                    ti.state = State.REMOVED

            should_restore_task = (task
                                   is not None) and ti.state == State.REMOVED
            if should_restore_task:
                self.log.info(
                    "Restoring task '%s' which was previously removed from DAG '%s'",
                    ti, dag)
                Stats.incr(f"task_restored_to_dag.{dag.dag_id}", 1, 1)
                ti.state = State.NONE
            session.merge(ti)

        def task_filter(task: "Operator") -> bool:
            return task.task_id not in task_ids and (
                self.is_backfill or task.start_date <= self.execution_date and
                (task.end_date is None
                 or self.execution_date <= task.end_date))

        created_counts: Dict[str, int] = defaultdict(int)

        # Set for the empty default in airflow.settings -- if it's not set this means it has been changed
        hook_is_noop = getattr(task_instance_mutation_hook, 'is_noop', False)

        if hook_is_noop:

            def create_ti_mapping(task: "Operator") -> dict:
                created_counts[task.task_type] += 1
                return TI.insert_mapping(self.run_id, task, map_index=-1)

        else:

            def create_ti(task: "Operator") -> TI:
                ti = TI(task, run_id=self.run_id)
                task_instance_mutation_hook(ti)
                created_counts[ti.operator] += 1
                return ti

        # Create missing tasks
        tasks = list(filter(task_filter, dag.task_dict.values()))
        try:
            if hook_is_noop:
                session.bulk_insert_mappings(TI, map(create_ti_mapping, tasks))
            else:
                session.bulk_save_objects(map(create_ti, tasks))

            for task_type, count in created_counts.items():
                Stats.incr(f"task_instance_created-{task_type}", count)
            session.flush()
        except IntegrityError:
            self.log.info(
                'Hit IntegrityError while creating the TIs for %s- %s',
                dag.dag_id,
                self.run_id,
                exc_info=True,
            )
            self.log.info('Doing session rollback.')
            # TODO[HA]: We probably need to savepoint this so we can keep the transaction alive.
            session.rollback()
Example #45
    def find(
        dag_id: Optional[Union[str, List[str]]] = None,
        run_id: Optional[str] = None,
        execution_date: Optional[datetime] = None,
        state: Optional[str] = None,
        external_trigger: Optional[bool] = None,
        no_backfills: bool = False,
        run_type: Optional[DagRunType] = None,
        session: Session = None,
        execution_start_date: Optional[datetime] = None,
        execution_end_date: Optional[datetime] = None
    ) -> List["DagRun"]:
        """
        Returns a set of dag runs for the given search criteria.

        :param dag_id: the dag_id or list of dag_id to find dag runs for
        :type dag_id: str or list[str]
        :param run_id: defines the run id for this dag run
        :type run_id: str
        :param run_type: type of DagRun
        :type run_type: airflow.utils.types.DagRunType
        :param execution_date: the execution date
        :type execution_date: datetime.datetime or list[datetime.datetime]
        :param state: the state of the dag run
        :type state: str
        :param external_trigger: whether this dag run is externally triggered
        :type external_trigger: bool
        :param no_backfills: return no backfills (True), return all (False).
            Defaults to False
        :type no_backfills: bool
        :param session: database session
        :type session: sqlalchemy.orm.session.Session
        :param execution_start_date: dag run that was executed from this date
        :type execution_start_date: datetime.datetime
        :param execution_end_date: dag run that was executed until this date
        :type execution_end_date: datetime.datetime
        """
        DR = DagRun

        qry = session.query(DR)
        dag_ids = [dag_id] if isinstance(dag_id, str) else dag_id
        if dag_ids:
            qry = qry.filter(DR.dag_id.in_(dag_ids))
        if run_id:
            qry = qry.filter(DR.run_id == run_id)
        if execution_date:
            if isinstance(execution_date, list):
                qry = qry.filter(DR.execution_date.in_(execution_date))
            else:
                qry = qry.filter(DR.execution_date == execution_date)
        if execution_start_date and execution_end_date:
            qry = qry.filter(DR.execution_date.between(execution_start_date, execution_end_date))
        elif execution_start_date:
            qry = qry.filter(DR.execution_date >= execution_start_date)
        elif execution_end_date:
            qry = qry.filter(DR.execution_date <= execution_end_date)
        if state:
            qry = qry.filter(DR.state == state)
        if external_trigger is not None:
            qry = qry.filter(DR.external_trigger == external_trigger)
        if run_type:
            qry = qry.filter(DR.run_type == run_type.value)
        if no_backfills:
            qry = qry.filter(DR.run_type != DagRunType.BACKFILL_JOB.value)

        dr = qry.order_by(DR.execution_date).all()

        return dr
Example #46
def _fill_test_database(db: Session) -> NoReturn:
    """Create dummy users and channels to allow further testing in dev mode."""
    from quetz.dao import Dao

    test_users = []
    dao = Dao(db)
    try:
        for index, username in enumerate(['alice', 'bob', 'carol', 'dave']):
            user = dao.create_user_with_role(username)

            identity = Identity(
                provider='dummy',
                identity_id=str(index),
                username=username,
            )

            profile = Profile(name=username.capitalize(),
                              avatar_url='/avatar.jpg')

            user.identities.append(identity)  # type: ignore
            user.profile = profile
            db.add(user)
            test_users.append(user)

        for channel_index in range(3):
            channel = Channel(
                name=f'channel{channel_index}',
                description=f'Description of channel{channel_index}',
                private=False,
            )

            for package_index in range(random.randint(5, 10)):
                package = Package(
                    name=f'package{package_index}',
                    summary=f'package {package_index} summary text',
                    description=f'Description of package{package_index}',
                )
                channel.packages.append(package)  # type: ignore

                test_user = test_users[random.randint(0, len(test_users) - 1)]
                package_member = PackageMember(package=package,
                                               channel=channel,
                                               user=test_user,
                                               role='owner')

                db.add(package_member)

            test_user = test_users[random.randint(0, len(test_users) - 1)]

            if channel_index == 0:
                package = Package(name='xtensor',
                                  description='Description of xtensor')
                channel.packages.append(package)  # type: ignore

                package_member = PackageMember(package=package,
                                               channel=channel,
                                               user=test_user,
                                               role='owner')

                db.add(package_member)

                # create API key
                key = uuid.uuid4().hex

                key_user = User(id=uuid.uuid4().bytes)
                api_key = ApiKey(key=key,
                                 description='test API key',
                                 user=test_user,
                                 owner=test_user)
                db.add(api_key)
                print(
                    f'Test API key created for user "{test_user.username}": {key}'
                )

                key_package_member = PackageMember(
                    user=key_user,
                    channel_name=channel.name,
                    package_name=package.name,
                    role='maintainer',
                )
                db.add(key_package_member)

            db.add(channel)

            channel_member = ChannelMember(
                channel=channel,
                user=test_user,
                role='owner',
            )

            db.add(channel_member)
        db.commit()
    finally:
        db.close()
Example #47
def downgrade():
    session = Session(bind=op.get_bind())
    q = Status.__table__.delete().where(
            Status.code.in_([StatusEnum.CREATED_FROM_FILE.name]))
    session.execute(q)
    session.commit()
Example #48
#!/usr/bin/python3
"""Start link class to table in database
"""
import sys

from sqlalchemy.orm.session import Session
from model_state import Base, State
from sqlalchemy.orm import sessionmaker
from sqlalchemy import (create_engine)

if __name__ == "__main__":
    engine = create_engine('mysql+mysqldb://{}:{}@localhost/{}'.format(
        sys.argv[1], sys.argv[2], sys.argv[3]),
                           pool_pre_ping=True)
    Base.metadata.create_all(engine)

    Session = sessionmaker(bind=engine)
    session = Session()
    for state in session.query(State).filter(State.name.like("%a%")) \
            .order_by(State.id):
        print("{}: {}".format(state.id, state.name))
    session.close()
Example #49
 def add(self, session: Session):
     session.add(self)
Example #50
def migrate_collections():
    session = Session(bind=op.get_bind())

    collection_map = {"dataset": DatasetORM, "reactiondataset": ReactionDatasetORM}

    collections = session.query(CollectionORM)

    for collection in collections:
        print(f"collection: id:{collection.id}, lname: {collection.lname}")

        if collection.collection in collection_map.keys():
            collection_class = collection_map[collection.collection]
        else:
            continue

        fields = collection.to_dict(exclude=["id"])
        session.query(CollectionORM).filter_by(id=collection.id).delete(synchronize_session=False)
        session.commit()

        dataset = collection_class(**fields)
        session.add(dataset)
        session.commit()

        dataset.update_relations(**fields)  # will use related attr
        session.commit()
Example #51
 def delete(session: Session, club_id: int) -> None:
     session.query(Club).filter(Club.id == club_id).delete()
Example #52
 def add(self, session: Session) -> None:
     session.add(self)
Example #53
 def query(cls):
     return Session.query(cls)
Example #54
 def create(self, session: Session, commit=True, **create_data: dict) -> T:
     model = self.model_class(**create_data)
     session.add(model)
     if commit:
         session.commit()
     return model
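
# A self-contained sketch of how the generic create() helper in Example #54
# might be wired up. Club, ClubRepository, and the in-memory SQLite URL are
# illustrative assumptions, not part of the original snippet.
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import declarative_base, sessionmaker

Base = declarative_base()


class Club(Base):
    __tablename__ = 'clubs'
    id = Column(Integer, primary_key=True)
    name = Column(String(64))


class ClubRepository:
    def __init__(self, model_class):
        self.model_class = model_class

    def create(self, session, commit=True, **create_data):
        model = self.model_class(**create_data)
        session.add(model)
        if commit:
            session.commit()
        return model


engine = create_engine('sqlite://')
Base.metadata.create_all(engine)
session = sessionmaker(bind=engine)()

club = ClubRepository(Club).create(session, name='Chess Club')
assert club.id is not None  # commit flushed the INSERT and assigned the key
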
#!/usr/bin/python3
"""Is a script that changes the name of a State object from the
database hbtn_0e_6_usa"""

from sqlalchemy.orm.session import Session
from sqlalchemy.sql.schema import MetaData

if __name__ == "__main__":

    from sqlalchemy import create_engine
    from sqlalchemy.orm import sessionmaker
    from sys import argv
    from model_state import Base, State

    engine = create_engine('mysql+mysqldb://{}:{}@localhost/{}'.format(
        argv[1], argv[2], argv[3]))

    Base.metadata.create_all(engine)
    Session = sessionmaker(bind=engine)

    session = Session()
    state = session.query(State).filter(State.id == 2).first()

    state.name = "New Mexico"

    session.commit()
    session.close()
Example #56
0
def test_get_user_by_email(user: User, session: Session,
                           user_service: UserService):
    assert session.query(User).count() == 1
    xuser = user_service.get_by_email(session, user.email)
    assert xuser is not None
    assert user == xuser
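
# For context, a hedged sketch of the lookup this test exercises; the real
# UserService.get_by_email may differ, and the `email` column name is only
# inferred from the call site above.
class UserService:
    def get_by_email(self, session: Session, email: str):
        return session.query(User).filter(User.email == email).one_or_none()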
Example #57
0
 def _dbsession(self):
     return Session.object_session(self.manager)
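
# Note on Example #57: Session.object_session() returns the Session that a
# persistent instance is currently attached to, or None for transient or
# detached objects. A minimal stand-alone sketch of the same idea; `instance`
# is any mapped object.
from sqlalchemy.orm.session import Session


def session_of(instance):
    # Recover the owning session without threading it through call sites.
    return Session.object_session(instance)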
Example #58
0
""" lists all State objects from the database hbtn_0e_6_usa
"""
import sys
from sqlalchemy.orm import session

from sqlalchemy.orm.session import Session
from model_state import Base, State
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm.exc import NoResultFound
from sqlalchemy import (create_engine)

if __name__ == "__main__":
    engine = create_engine('mysql+mysqldb://{}:{}@localhost/{}'.format(
        sys.argv[1], sys.argv[2], sys.argv[3]),
                           pool_pre_ping=True)

    # Base.metadata.create_all(engine)

    Session = sessionmaker()
    Session.configure(bind=engine)

    session = Session()

    result = session.query(State).first()
    if result is not None:
        print("{}: {}".format(result.id, result.name))
    else:
        print("Nothing")

    session.close()
Example #59
0
    def verify_integrity(self, session: Session = NEW_SESSION):
        """
        Verifies the DagRun by checking for removed tasks or tasks that are not in the
        database yet. It will set state to removed or add the task if required.

        :param session: Sqlalchemy ORM Session
        """
        from airflow.settings import task_instance_mutation_hook

        dag = self.get_dag()
        tis = self.get_task_instances(session=session)

        # check for removed or restored tasks
        task_ids = set()
        for ti in tis:
            task_instance_mutation_hook(ti)
            task_ids.add(ti.task_id)
            task = None
            try:
                task = dag.get_task(ti.task_id)

                should_restore_task = (task is not None) and ti.state == State.REMOVED
                if should_restore_task:
                    self.log.info("Restoring task '%s' which was previously removed from DAG '%s'", ti, dag)
                    Stats.incr(f"task_restored_to_dag.{dag.dag_id}", 1, 1)
                    ti.state = State.NONE
            except AirflowException:
                if ti.state == State.REMOVED:
                    pass  # ti has already been removed, just ignore it
                elif self.state != State.RUNNING and not dag.partial:
                    self.log.warning("Failed to get task '%s' for dag '%s'. Marking it as removed.", ti, dag)
                    Stats.incr(f"task_removed_from_dag.{dag.dag_id}", 1, 1)
                    ti.state = State.REMOVED
                continue

            if not task.is_mapped:
                continue
            task = cast("MappedOperator", task)
            num_mapped_tis = task.parse_time_mapped_ti_count
            # Check if the number of mapped literals has changed and we need to mark this TI as removed
            if num_mapped_tis is not None:
                if ti.map_index >= num_mapped_tis:
                    self.log.debug(
                        "Removing task '%s' as the map_index is longer than the literal mapping list (%s)",
                        ti,
                        num_mapped_tis,
                    )
                    ti.state = State.REMOVED
                elif ti.map_index < 0:
                    self.log.debug("Removing the unmapped TI '%s' as the mapping can now be performed", ti)
                    ti.state = State.REMOVED
                else:
                    self.log.info("Restoring mapped task '%s'", ti)
                    Stats.incr(f"task_restored_to_dag.{dag.dag_id}", 1, 1)
                    ti.state = State.NONE
            else:
                #  What if it is _now_ dynamically mapped, but wasn't before?
                total_length = task.run_time_mapped_ti_count(self.run_id, session=session)

                if total_length is None:
                    # Not all upstreams finished, so we can't tell what should be here. Remove everything.
                    if ti.map_index >= 0:
                        self.log.debug(
                            "Removing the unmapped TI '%s' as the mapping can't be resolved yet", ti
                        )
                        ti.state = State.REMOVED
                    continue
                # Upstreams finished, check there aren't any extras
                if ti.map_index >= total_length:
                    self.log.debug(
                        "Removing task '%s' as the map_index is longer than the resolved mapping list (%d)",
                        ti,
                        total_length,
                    )
                    ti.state = State.REMOVED
                    ...

        def task_filter(task: "Operator") -> bool:
            return task.task_id not in task_ids and (
                self.is_backfill
                or task.start_date <= self.execution_date
                and (task.end_date is None or self.execution_date <= task.end_date)
            )

        created_counts: Dict[str, int] = defaultdict(int)

        # Set for the empty default in airflow.settings -- if it's not set this means it has been changed
        hook_is_noop = getattr(task_instance_mutation_hook, 'is_noop', False)

        if hook_is_noop:

            def create_ti_mapping(task: "Operator", indexes: Tuple[int, ...]) -> Generator:
                created_counts[task.task_type] += 1
                for map_index in indexes:
                    yield TI.insert_mapping(self.run_id, task, map_index=map_index)

            creator = create_ti_mapping

        else:

            def create_ti(task: "Operator", indexes: Tuple[int, ...]) -> Generator:
                for map_index in indexes:
                    ti = TI(task, run_id=self.run_id, map_index=map_index)
                    task_instance_mutation_hook(ti)
                    created_counts[ti.operator] += 1
                    yield ti

            creator = create_ti

        # Create missing tasks -- and expand any MappedOperator that _only_ have literals as input
        def expand_mapped_literals(task: "Operator") -> Tuple["Operator", Sequence[int]]:
            if not task.is_mapped:
                return (task, (-1,))
            task = cast("MappedOperator", task)
            count = task.parse_time_mapped_ti_count or task.run_time_mapped_ti_count(
                self.run_id, session=session
            )
            if not count:
                return (task, (-1,))
            return (task, range(count))

        tasks_and_map_idxs = map(expand_mapped_literals, filter(task_filter, dag.task_dict.values()))
        tasks = itertools.chain.from_iterable(itertools.starmap(creator, tasks_and_map_idxs))

        try:
            if hook_is_noop:
                session.bulk_insert_mappings(TI, tasks)
            else:
                session.bulk_save_objects(tasks)

            for task_type, count in created_counts.items():
                Stats.incr(f"task_instance_created-{task_type}", count)
            session.flush()
        except IntegrityError:
            self.log.info(
                'Hit IntegrityError while creating the TIs for %s- %s',
                dag.dag_id,
                self.run_id,
                exc_info=True,
            )
            self.log.info('Doing session rollback.')
            # TODO[HA]: We probably need to savepoint this so we can keep the transaction alive.
            session.rollback()
Example #60
0
    def manage_slas(self, dag: DAG, session: Session = None) -> None:
        """
        Find all tasks that have SLAs defined and send alert emails where
        needed. New SLA misses are also recorded in the database.

        We are assuming that the scheduler runs often, so we only check for
        tasks that should have succeeded in the past hour.
        """
        self.log.info("Running SLA Checks for %s", dag.dag_id)
        if not any(isinstance(ti.sla, timedelta) for ti in dag.tasks):
            self.log.info(
                "Skipping SLA check for %s because no tasks in DAG have SLAs",
                dag)
            return

        qry = (session.query(
            TI.task_id,
            func.max(DR.execution_date).label('max_ti')).join(
                TI.dag_run).with_hint(
                    TI, 'USE INDEX (PRIMARY)', dialect_name='mysql').filter(
                        TI.dag_id == dag.dag_id).filter(
                            or_(TI.state == State.SUCCESS,
                                TI.state == State.SKIPPED)).filter(
                                    TI.task_id.in_(dag.task_ids)).group_by(
                                        TI.task_id).subquery('sq'))
        # get recorded SlaMiss
        recorded_slas_query = set(
            session.query(SlaMiss.dag_id, SlaMiss.task_id,
                          SlaMiss.execution_date).filter(
                              SlaMiss.dag_id == dag.dag_id,
                              SlaMiss.task_id.in_(dag.task_ids)))

        max_tis: Iterator[TI] = (session.query(TI).join(TI.dag_run).filter(
            TI.dag_id == dag.dag_id,
            TI.task_id == qry.c.task_id,
            DR.execution_date == qry.c.max_ti,
        ))

        ts = timezone.utcnow()

        for ti in max_tis:
            task = dag.get_task(ti.task_id)
            if not task.sla:
                continue

            if not isinstance(task.sla, timedelta):
                raise TypeError(
                    f"SLA is expected to be timedelta object, got "
                    f"{type(task.sla)} in {task.dag_id}:{task.task_id}")

            sla_misses = []
            next_info = dag.next_dagrun_info(dag.get_run_data_interval(
                ti.dag_run),
                                             restricted=False)
            if next_info is None:
                self.log.info(
                    "Skipping SLA check for %s because task does not have scheduled date",
                    ti)
            else:
                while next_info.logical_date < ts:
                    next_info = dag.next_dagrun_info(next_info.data_interval,
                                                     restricted=False)

                    if next_info is None:
                        break
                    if (ti.dag_id, ti.task_id,
                            next_info.logical_date) in recorded_slas_query:
                        break
                    if next_info.logical_date + task.sla < ts:

                        sla_miss = SlaMiss(
                            task_id=ti.task_id,
                            dag_id=ti.dag_id,
                            execution_date=next_info.logical_date,
                            timestamp=ts,
                        )
                        sla_misses.append(sla_miss)
            if sla_misses:
                session.add_all(sla_misses)
        session.commit()

        slas: List[SlaMiss] = (
            session.query(SlaMiss).filter(SlaMiss.notification_sent == False,
                                          SlaMiss.dag_id == dag.dag_id)  # noqa
            .all())
        if slas:
            sla_dates: List[datetime.datetime] = [
                sla.execution_date for sla in slas
            ]
            fetched_tis: List[TI] = (session.query(TI).filter(
                TI.state != State.SUCCESS, TI.execution_date.in_(sla_dates),
                TI.dag_id == dag.dag_id).all())
            blocking_tis: List[TI] = []
            for ti in fetched_tis:
                if ti.task_id in dag.task_ids:
                    ti.task = dag.get_task(ti.task_id)
                    blocking_tis.append(ti)
                else:
                    session.delete(ti)
                    session.commit()

            task_list = "\n".join(sla.task_id + ' on ' +
                                  sla.execution_date.isoformat()
                                  for sla in slas)
            blocking_task_list = "\n".join(ti.task_id + ' on ' +
                                           ti.execution_date.isoformat()
                                           for ti in blocking_tis)
            # Track whether email or any alert notification sent
            # We consider email or the alert callback as notifications
            email_sent = False
            notification_sent = False
            if dag.sla_miss_callback:
                # Execute the alert callback
                self.log.info('Calling SLA miss callback')
                try:
                    dag.sla_miss_callback(dag, task_list, blocking_task_list,
                                          slas, blocking_tis)
                    notification_sent = True
                except Exception:
                    self.log.exception(
                        "Could not call sla_miss_callback for DAG %s",
                        dag.dag_id)
            email_content = f"""\
            Here's a list of tasks that missed their SLAs:
            <pre><code>{task_list}\n</code></pre>
            Blocking tasks:
            <pre><code>{blocking_task_list}</code></pre>
            Airflow Webserver URL: {conf.get(section='webserver', key='base_url')}
            """

            tasks_missed_sla = []
            for sla in slas:
                try:
                    task = dag.get_task(sla.task_id)
                except TaskNotFound:
                    # task already deleted from DAG, skip it
                    self.log.warning(
                        "Task %s doesn't exist in DAG anymore, skipping SLA miss notification.",
                        sla.task_id)
                    continue
                tasks_missed_sla.append(task)

            emails: Set[str] = set()
            for task in tasks_missed_sla:
                if task.email:
                    if isinstance(task.email, str):
                        emails |= set(get_email_address_list(task.email))
                    elif isinstance(task.email, (list, tuple)):
                        emails |= set(task.email)
            if emails:
                try:
                    send_email(emails,
                               f"[airflow] SLA miss on DAG={dag.dag_id}",
                               email_content)
                    email_sent = True
                    notification_sent = True
                except Exception:
                    Stats.incr('sla_email_notification_failure')
                    self.log.exception(
                        "Could not send SLA Miss email notification for DAG %s",
                        dag.dag_id)
            # If we sent any notification, update the sla_miss table
            if notification_sent:
                for sla in slas:
                    sla.email_sent = email_sent
                    sla.notification_sent = True
                    session.merge(sla)
            session.commit()