def verify_integrity(self, session: Session = None):
    """
    Verifies the DagRun by checking for removed tasks or tasks that are not in the
    database yet. It will set state to removed or add the task if required.

    Existing task instances whose task is missing from the DAG are marked REMOVED
    (unless this run is RUNNING or the DAG is partial); REMOVED instances whose
    task is back in the DAG are reset to NONE. Tasks without a task instance get
    one created. The session is then committed; an IntegrityError on commit is
    logged and the session rolled back.

    :param session: Sqlalchemy ORM Session
    :type session: Session
    """
    dag = self.get_dag()
    tis = self.get_task_instances(session=session)

    # check for removed or restored tasks
    task_ids = set()
    for ti in tis:
        task_instance_mutation_hook(ti)
        task_ids.add(ti.task_id)
        task = None
        try:
            task = dag.get_task(ti.task_id)
        except AirflowException:
            if ti.state == State.REMOVED:
                pass  # ti has already been removed, just ignore it
            # Compare by value: a state string loaded from the DB is generally not
            # the same object as the State constant, so `is not` was always true.
            elif self.state != State.RUNNING and not dag.partial:
                self.log.warning(
                    "Failed to get task '%s' for dag '%s'. Marking it as removed.", ti, dag
                )
                Stats.incr("task_removed_from_dag.{}".format(dag.dag_id), 1, 1)
                ti.state = State.REMOVED

        should_restore_task = (task is not None) and ti.state == State.REMOVED
        if should_restore_task:
            self.log.info(
                "Restoring task '%s' which was previously removed from DAG '%s'", ti, dag
            )
            Stats.incr("task_restored_to_dag.{}".format(dag.dag_id), 1, 1)
            ti.state = State.NONE
        session.merge(ti)

    # check for missing tasks
    for task in dag.task_dict.values():
        # Tasks starting after this run's date only get instances for backfills.
        if task.start_date > self.execution_date and not self.is_backfill:
            continue

        if task.task_id not in task_ids:
            Stats.incr("task_instance_created-{}".format(task.task_type), 1, 1)
            ti = TI(task, self.execution_date)
            task_instance_mutation_hook(ti)
            session.add(ti)

    try:
        session.commit()
    except IntegrityError as err:
        self.log.info(str(err))
        self.log.info(
            'Hit IntegrityError while creating the TIs for %s - %s.', dag.dag_id, self.execution_date
        )
        self.log.info('Doing session rollback.')
        session.rollback()
def test_set_object_local_id(self):
    """
    Test the method _set_object_local_id(self, obj, local_id)

    Test scenario:
    Set the local_id of the specified object when the pvc_id is none
    """
    obj_id = self.powerVCMapping.id
    self.powerVCMapping.pvc_id = None
    self.powerVCMapping.local_id = None
    self.powerVCMapping.status = None

    try:
        # Record the DB calls _set_object_local_id is expected to make.
        self.aMox.StubOutWithMock(session, 'query')
        session.query(model.PowerVCMapping).AndReturn(query)

        self.aMox.StubOutWithMock(query, 'filter_by')
        query.filter_by(id=obj_id).AndReturn(query)

        self.aMox.StubOutWithMock(query, 'one')
        query.one().AndReturn(self.powerVCMapping)

        self.aMox.StubOutWithMock(session, 'merge')
        session.merge(self.powerVCMapping).AndReturn("")

        self.aMox.ReplayAll()

        self.powervcagentdb._set_object_local_id(self.powerVCMapping, 'test')

        self.aMox.VerifyAll()

        self.assertEqual(self.powerVCMapping.status, 'Creating')
        self.assertEqual(self.powerVCMapping.local_id, 'test')
    finally:
        # Restore the stubbed `session`/`query` objects even when VerifyAll or an
        # assertion fails, so later tests do not inherit the mocks.
        self.aMox.UnsetStubs()
def verify_integrity(self, session: Session = NEW_SESSION):
    """
    Verifies the DagRun by checking for removed tasks or tasks that are not in the
    database yet. It will set state to removed or add the task if required.

    Existing task instances whose task is missing from the DAG are marked REMOVED
    (unless this run is RUNNING or the DAG is partial); REMOVED instances whose
    task is back in the DAG are reset to NONE. Tasks without a task instance get
    one created for this run_id. The session is flushed, not committed; an
    IntegrityError on flush is logged and the session rolled back.

    :param session: Sqlalchemy ORM Session
    :type session: Session
    """
    from airflow.settings import task_instance_mutation_hook

    dag = self.get_dag()
    tis = self.get_task_instances(session=session)

    # check for removed or restored tasks
    task_ids = set()
    for ti in tis:
        task_instance_mutation_hook(ti)
        task_ids.add(ti.task_id)
        task = None
        try:
            task = dag.get_task(ti.task_id)
        except AirflowException:
            if ti.state == State.REMOVED:
                pass  # ti has already been removed, just ignore it
            elif self.state != State.RUNNING and not dag.partial:
                self.log.warning("Failed to get task '%s' for dag '%s'. Marking it as removed.", ti, dag)
                Stats.incr(f"task_removed_from_dag.{dag.dag_id}", 1, 1)
                ti.state = State.REMOVED

        should_restore_task = (task is not None) and ti.state == State.REMOVED
        if should_restore_task:
            self.log.info("Restoring task '%s' which was previously removed from DAG '%s'", ti, dag)
            Stats.incr(f"task_restored_to_dag.{dag.dag_id}", 1, 1)
            ti.state = State.NONE
        session.merge(ti)

    # check for missing tasks
    for task in dag.task_dict.values():
        # Tasks starting after this run's date only get instances for backfills.
        if task.start_date > self.execution_date and not self.is_backfill:
            continue

        if task.task_id not in task_ids:
            Stats.incr(f"task_instance_created-{task.task_type}", 1, 1)
            ti = TI(task, execution_date=None, run_id=self.run_id)
            task_instance_mutation_hook(ti)
            session.add(ti)

    try:
        session.flush()
    except IntegrityError as err:
        self.log.info(str(err))
        self.log.info('Hit IntegrityError while creating the TIs for %s - %s', dag.dag_id, self.run_id)
        self.log.info('Doing session rollback.')
        # TODO[HA]: We probably need to savepoint this so we can keep the transaction alive.
        session.rollback()
def synchronize_log_template(*, session: Session = NEW_SESSION) -> None:
    """Keep the ``LogTemplate`` table in step with the logging configuration.

    The newest row (highest ``id``) is compared against the configured
    ``[logging] log_filename_template`` and ``[logging] task_log_prefix_template``
    values; a new row with the configured values is merged unless both match.
    """
    newest = session.query(LogTemplate).order_by(LogTemplate.id.desc()).first()
    configured_filename = conf.get("logging", "log_filename_template")
    configured_prefix = conf.get("logging", "task_log_prefix_template")

    in_sync = (
        bool(newest)
        and newest.filename == configured_filename
        and newest.task_prefix == configured_prefix
    )
    if not in_sync:
        session.merge(LogTemplate(filename=configured_filename, task_prefix=configured_prefix))
def add(cls, session: Session, mpid: Union[int, str]) -> None:
    """
    Record that a patient has opted out.

    Args:
        session: SQLAlchemy session for the secret admin database
        mpid: MPID of the opting-out patient
    """
    log.debug(f"Adding opt-out for MPID {mpid}")
    # merge() rather than add(); presumably mpid is the primary key so a repeated
    # opt-out reuses the existing row — confirm against the model definition.
    # noinspection PyArgumentList
    session.merge(cls(mpid=mpid))
def downgrade():
    """Undo the data-source-configuration link.

    Makes ``data_source.datasource_configuration_id`` nullable again, clears it on
    every ``DataSource`` row and deletes every ``DataSourceConfiguration`` row.
    """
    session = Session(op.get_bind())
    op.alter_column('data_source', 'datasource_configuration_id', nullable=True)
    for ds in session.query(DataSource).all():
        # The row is already attached to this session, so assigning is enough;
        # merging it back into the same session was a no-op.
        ds.datasource_configuration_id = None
    # Write the cleared references before the bulk delete so no data_source row
    # still points at a configuration that is about to be removed.
    session.flush()
    session.query(DataSourceConfiguration).delete()
    # Persist the changes; without this they could be discarded with the session.
    session.commit()
def synchronize_log_filename_template(*, session: Session = NEW_SESSION) -> None:
    """Keep the ``LogFilename`` table in step with the configured filename template.

    The template of the newest row (highest ``id``) is compared with
    ``[logging] LOG_FILENAME_TEMPLATE``; when they differ (or the table is empty),
    a new row holding the configured template is merged.
    """
    latest_template = (
        session.query(LogFilename.template)
        .order_by(LogFilename.id.desc())
        .limit(1)
        .scalar()
    )
    configured = conf.get("logging", "LOG_FILENAME_TEMPLATE")
    if latest_template != configured:
        session.merge(LogFilename(template=configured))
def insert_instruments(session: Session, dest_ric: Path, logger: Logger) -> None:
    """Load instruments from a CSV file and merge each one into the database.

    The first row is treated as a header and skipped. Columns are read
    positionally as ric, description, currency, type and exchange; the exchange
    is kept only for ``EQUITY`` and ``FUTURES`` instruments. Each instrument is
    merged and committed individually.

    :param session: SQLAlchemy session used for the merges and commits.
    :param dest_ric: path of the CSV file to read.
    :param logger: not used by this function.
    """
    # The csv module requires newline='' so quoted fields containing line
    # breaks are parsed correctly.
    with dest_ric.open(mode='r', newline='') as f:
        reader = csv.reader(f, delimiter=',')
        # Skip the header; the default avoids StopIteration on an empty file.
        next(reader, None)
        for fields in reader:
            ric = fields[0]
            desc = fields[1]
            currency = fields[2]
            type_ = fields[3]
            exchange = fields[4] if type_ in [EQUITY, FUTURES] else None
            instrument = Instrument(ric, desc, currency, type_, exchange)
            session.merge(instrument)
            session.commit()
def draw_card(session: Session, order: Order) -> str:
    """Randomly draw a card for the given order.

    ### Args:
    ``session``: SQLAlchemy session used to access the database.

    ``order``: the order to draw a card for.

    ### Result:
    ``message``: feedback message, built from the ``[card] pattern`` config template.

    ### Raises:
    ``ValueError``: no card with the drawn rarity exists.
    """
    if not hasattr(order, 'nickname'):
        order.nickname = ''

    # Roll the rarity: the spread of the roll grows with the order amount.
    threshold = float(setting.read_config('card', 'threshold'))
    divend = order.amount / threshold
    rand = abs(random.normalvariate(0, math.sqrt(divend)))
    if divend > 25:
        # Large orders get an extra logarithmic bonus on top of the roll.
        rand += math.log2(divend / 25)
    # Bucket the roll: (1, 2.5] -> 1, (2.5, 5] -> 2, > 5 -> 3; rolls <= 1 are
    # truncated by int() below.
    if 1 < rand <= 2.5:
        rand = 1
    elif 2.5 < rand <= 5:
        rand = 2
    elif rand > 5:
        rand = 3
    rarity = int(rand)

    card_query = session.query(Card).filter(Card.rarity == rarity)
    # Count once: the number is needed for the random pick and the message.
    total_cards = card_query.count()
    if total_cards == 0:
        # randint(0, -1) would otherwise fail with an opaque "empty range" error.
        raise ValueError(f'No card with rarity {rarity} is available to draw')
    card = card_query[random.randint(0, total_cards - 1)]

    # Record the draw against the order.
    session.add(
        Card_Order(order_id=order.id, rarity=rarity, type_id=card.type_id))

    # Record the draw against the user.
    user = find_user(session, order.platform, order.user_id, order.nickname)
    order.nickname = user.nickname
    session.merge(
        Card_User(user_id=user.id, rarity=rarity, type_id=card.type_id))
    session.flush()

    # Build the feedback message.
    collected_cards = session.query(Card_User).\
        filter(Card_User.user_id == user.id).\
        filter(Card_User.rarity == rarity).count()
    info_dict = {
        'nickname': order.nickname,
        'rarity': setting.rarity()[card.rarity],
        'name': card.name,
        'context': card.context,
        'user_amount': collected_cards,
        'total_amount': total_cards,
        'image': f'[CQ:image,file={card.file_name}]',
    }
    logger.debug('%s抽取到一张%s卡:%s', user.nickname, info_dict['rarity'], card.name)
    pattern = setting.read_config('card', 'pattern')
    return pattern.format(**info_dict)
def _update_existing_or_create(ingested_entity: schema.JusticeCountsDatabaseEntity, session: Session) \
        -> schema.JusticeCountsDatabaseEntity:
    """Insert ``ingested_entity``, or update the stored row matching its unique constraint.

    The entity's table must declare exactly one ``UniqueConstraint``; unpacking
    raises ``ValueError`` otherwise. If a row matches the entity on every
    constrained column, that row's ``id`` is copied onto the entity and the
    result of ``session.merge`` is returned. Otherwise the entity is added to
    the session and returned unchanged.
    """
    # Note: Using on_conflict_do_update to resolve whether there is an existing entity could be more efficient as it
    # wouldn't incur multiple roundtrips. However for some entities we need to know whether there is an existing entity
    # (e.g. table instance) so we can clear child entities, so we probably wouldn't win much if anything.
    table = ingested_entity.__table__
    unique_constraints = [c for c in table.constraints if isinstance(c, UniqueConstraint)]
    [unique_constraint] = unique_constraints

    def _constraint_value(column):
        # TODO(#4477): Instead of making an assumption about how the property name is formed from the column name, use
        # an Entity method here to follow the foreign key relationship.
        if not column.name.endswith('_id'):
            return getattr(ingested_entity, column.name)
        return getattr(ingested_entity, column.name[:-len('_id')]).id

    query = session.query(table)
    for column in unique_constraint:
        # Cast to the type because array types aren't deduced properly.
        query = query.filter(column == cast(_constraint_value(column), column.type))

    table_entity: Optional[JusticeCountsBase] = query.first()
    if table_entity is None:
        session.add(ingested_entity)
        return ingested_entity

    # TODO(#4477): Instead of assuming the primary key field is named `id`, use an Entity method.
    ingested_entity.id = table_entity.id
    # TODO(#4477): Merging here doesn't seem perfect, although it should work so long as the given entity always has
    # all the properties set explicitly. To avoid the merge, the method could instead take in the entity class as
    # one parameter and the parameters to construct it separately and then query based on those parameters. However
    # this would likely make mypy less useful.
    return session.merge(ingested_entity)
def update(self, entity_class, data, target=None):
    """Update an entity from ``data`` and return the updated entity.

    If ``data`` already provides ``IEntity``, it is merged directly through
    ``SaSession.merge``. Otherwise it is applied to ``target`` via the private
    relation traversal with the ``UPDATE`` operation.
    """
    if IEntity.providedBy(data):  # pylint: disable=E1101
        return SaSession.merge(self, data)
    return self.__run_traversal(entity_class, data, target,
                                RELATION_OPERATIONS.UPDATE)
def _clear_stuck_queued_tasks(self, session: Session = NEW_SESSION) -> None:
    """
    Tasks can get stuck in queued state in DB while still not in worker.
    This happens when the worker is autoscaled down and
    the task is queued but has not been picked up by any worker prior to the scaling.

    In such a situation, we update the task instance state to scheduled so that
    it can be queued again. We chose to use task_adoption_timeout to decide when
    a queued task is considered stuck and should be rescheduled.
    """
    if not isinstance(app.backend, DatabaseBackend):
        # We only want to do this for database backends where
        # this case has been spotted
        return
    # We use this instead of using bulk_state_fetcher because we
    # may not have the stuck task in self.tasks and we don't want
    # to clear task in self.tasks too
    session_ = app.backend.ResultSession()
    task_cls = getattr(app.backend, "task_cls", TaskDb)
    with session_cleanup(session_):
        # A set, since it is only probed for membership once per queued TI below.
        celery_task_ids = {
            t.task_id
            for t in session_.query(task_cls.task_id)
            .filter(~task_cls.status.in_([celery_states.SUCCESS, celery_states.FAILURE]))
            .all()
        }
    self.log.debug("Checking for stuck queued tasks")

    max_allowed_time = utcnow() - self.task_adoption_timeout

    for task in session.query(TaskInstance).filter(
        TaskInstance.state == State.QUEUED, TaskInstance.queued_dttm < max_allowed_time
    ):
        if task.key in self.queued_tasks or task.key in self.running:
            continue

        if task.external_executor_id in celery_task_ids:
            # The task is still running in the worker
            continue

        self.log.info(
            'TaskInstance: %s found in queued state for more than %s seconds, rescheduling',
            task,
            self.task_adoption_timeout.total_seconds(),
        )
        task.state = State.SCHEDULED
        session.merge(task)
def _set_dag_run_state(dag_id: str, run_id: str, state: DagRunState, session: SASession = NEW_SESSION):
    """
    Set the state of one dag run in the DB.

    Moving to RUNNING stamps ``start_date`` with the current time and clears
    ``end_date``; any other state stamps ``end_date``.

    :param dag_id: dag_id of target dag run
    :param run_id: run id of target dag run
    :param state: target state
    :param session: database session
    """
    dag_run = (
        session.query(DagRun)
        .filter(DagRun.dag_id == dag_id, DagRun.run_id == run_id)
        .one()
    )
    dag_run.state = state
    becomes_running = state == State.RUNNING
    if not becomes_running:
        dag_run.end_date = timezone.utcnow()
    else:
        dag_run.start_date = timezone.utcnow()
        dag_run.end_date = None
    session.merge(dag_run)
def load_chunk(session: Session, fetch_from: date, fetch_to: date, now: date, tan_callback):
    """Fetch transactions for a date range over HBCI and merge them into ``session``.

    Transactions dated after ``now`` or outside ``[fetch_from, fetch_to]``
    (widened by ``config.import_fuzz_grace_days``) are skipped.

    :param tan_callback: called with a TAN challenge; a falsy result aborts.
    :return: ``True`` when the fetched transactions were merged; ``False`` when
        no TAN was supplied, or when the bank reported an error response (the
        session is rolled back in that case).
    """
    global error
    logging.info("Fetching TXs from %s to %s", fetch_from, fetch_to)
    error = False
    with hbci_client.get_account(tan_callback) as (conn, acc):
        def log_callback(_, response):
            global error
            if response.code[0] not in ('0', '1', '3'):  # 0&1 info, 3 warning, rest err
                error = True

        conn.add_response_callback(log_callback)
        txs = conn.get_transactions(acc, fetch_from, fetch_to)
        while isinstance(txs, NeedTANResponse):
            logging.info("Calling tan callback for TXs %s", txs)
            tan = tan_callback(txs)
            if not tan:
                logging.error("No TAN got, aborting")
                # Return a bool like every other exit path (previously None).
                return False
            txs = conn.send_tan(txs, tan)

    logging.info("got txs: len: %s, type: %s, %s", len(txs), type(txs), txs)
    # Grace window around the requested range; invariant, so computed once.
    fuzz = timedelta(days=config.import_fuzz_grace_days)
    new_txs = []
    for raw_tx in txs:
        tx = Transaction(raw_tx)
        if tx.entry_date > now:
            logging.warning("Ignoring future tx: %s", tx)
            continue
        if tx.date + fuzz < fetch_from or tx.date - fuzz > fetch_to:
            logging.warning("Ignoring tx which is not in requested range: %s", tx)
            continue
        new_txs.append(tx)
    logging.info("Fetched %d txs", len(new_txs))
    if error:
        logging.error("Errors occurred")
        session.rollback()
        return False
    for tx in new_txs:
        session.merge(tx)
    return True
def upgrade():
    """Give every company a data-source configuration and link its data sources to it.

    A ``DataSourceConfiguration`` is created per company, each of the company's
    ``DataSource`` rows gets that configuration's id and meta, and finally
    ``data_source.datasource_configuration_id`` is made NOT NULL.
    """
    session = Session(op.get_bind())
    for company in session.query(Company).all():
        datasource_config = DataSourceConfiguration(
            name=DATASOURCE_CONFIG_NAME,
            meta=DATASOURCE_META,
            company_id=company.id)
        session.add(datasource_config)
        # Flush explicitly so the primary key exists before it is copied below,
        # instead of relying on the next query's autoflush.
        session.flush()
        for ds_entity in session.query(DataSource).filter(
                DataSource.company_id == company.id).all():
            # Already attached to this session; merging it back was a no-op.
            ds_entity.datasource_configuration_id = datasource_config.id
            ds_entity.meta = datasource_config.meta
    # Persist every pending change before tightening the column. Otherwise the
    # last company's data sources are never flushed and would still be NULL
    # when NOT NULL is applied.
    session.commit()
    op.alter_column('data_source', 'datasource_configuration_id', nullable=False)
def verify_dagruns(
    dag_runs: Iterable[DagRun],
    commit: bool,
    state: DagRunState,
    session: SASession,
    current_task: BaseOperator,
):
    """Verifies integrity of dag_runs.

    Each dag run is pointed at ``current_task.subdag`` and its integrity is
    verified within ``session``. When ``commit`` is true, its state is set to
    ``state`` and the run is merged into ``session``.

    :param dag_runs: dag runs to verify
    :param commit: whether dag runs state should be updated
    :param state: state of the dag_run to set if commit is True
    :param session: session to use
    :param current_task: current task
    :return:
    """
    for dag_run in dag_runs:
        dag_run.dag = current_task.subdag
        # Verify within the caller's session so the integrity fixes and the
        # state change below share one session instead of a separate one.
        dag_run.verify_integrity(session=session)
        if commit:
            dag_run.state = state
            session.merge(dag_run)
def update_dag_warnings(self, *, session: Session, dagbag: DagBag) -> None:
    """
    Sync stored ``DagWarning`` rows for the DAGs in ``dagbag`` with ``self.dag_warnings``.

    Task-pool validation runs first. Stored warnings for these DAGs that are no
    longer present in ``self.dag_warnings`` are deleted, every current warning is
    merged, and the session is committed. These are usually displayed through the
    Airflow UI so that users know that there are issues parsing DAGs.

    :param session: session for ORM operations
    :param dagbag: DagBag containing DAGs with configuration warnings
    """
    self._validate_task_pools(dagbag=dagbag)

    persisted = set(
        session.query(DagWarning)
        .filter(DagWarning.dag_id.in_(dagbag.dags.keys()))
        .all()
    )
    obsolete = persisted - self.dag_warnings

    for stale_warning in obsolete:
        session.delete(stale_warning)
    for current_warning in self.dag_warnings:
        session.merge(current_warning)
    session.commit()
def add_and_scan_channel(
    session: Session,
    channel: schema.Channel,
    autodownload: bool,
    dryrun: bool,
) -> None:
    """Subscribe to ``channel`` and mark all of its current videos as downloaded.

    If ``retrieve_videos`` raises, the error is printed to stderr and the
    function returns without committing. Otherwise every retrieved video gets
    ``downloaded = 1`` and is merged, and the session is committed unless
    ``dryrun`` is set.
    """
    channel.autodownload = autodownload
    session.merge(channel)

    try:
        videos = retrieve_videos(channel)
    except Exception as e:
        print(f"Couldn't retrieve videos: {e}", file=sys.stderr)
        return

    for video in videos:
        video.downloaded = 1
        session.merge(video)

    if not dryrun:
        session.commit()
    print(f'Subscribed to "{channel.name}"')
def _set_dag_run_state(dag_id: str, execution_date: datetime, state: TaskInstanceState,
                       session: SASession = NEW_SESSION):
    """
    Set the state of the dag run identified by ``dag_id`` and ``execution_date``.

    Moving to RUNNING stamps ``start_date`` with the current time and clears
    ``end_date``; any other state stamps ``end_date``.

    :param dag_id: dag_id of target dag run
    :param execution_date: the execution date from which to start looking
    :param state: target state
    :param session: database session
    """
    run = (
        session.query(DagRun)
        .filter(DagRun.dag_id == dag_id, DagRun.execution_date == execution_date)
        .one()
    )
    run.state = state
    becomes_running = state == TaskInstanceState.RUNNING
    if not becomes_running:
        run.end_date = timezone.utcnow()
    else:
        run.start_date = timezone.utcnow()
        run.end_date = None
    session.merge(run)
def _add_item(config: Config, session: Session, item_path: Path):
    """Add the track or album at ``item_path`` to the library.

    A file is added as a track and its album is used; a directory is added as
    an album. The album is merged with any existing copy (without overwriting
    album info). The ``import_album`` hook runs, and the user is prompted when
    it returns candidates. If an album results, ``pre_add`` runs and the album
    is merged into the session.

    Args:
        config: Moe config.
        session: Current db session.
        item_path: Filesystem path of the item.

    Raises:
        AddError: Unable to add the item to the library.
    """
    if item_path.is_file():
        old_album = _add_track(item_path).album_obj
    elif item_path.is_dir():
        old_album = _add_album(item_path)
    else:
        raise AddError(f"Path not found: {item_path}")

    old_album.merge(old_album.get_existing(session), overwrite_album_info=False)

    candidates = config.plugin_manager.hook.import_album(
        config=config, session=session, album=old_album
    )
    add_album = (
        prompt.run_prompt(config, session, old_album, candidates[0])
        if candidates
        else old_album
    )

    if not add_album:
        return
    config.plugin_manager.hook.pre_add(config=config, session=session, album=add_album)
    session.merge(add_album)
def expand_mapped_task(self, run_id: str, *, session: Session) -> Sequence["TaskInstance"]:
    """Create the mapped task instances for mapped task.

    The expansion length is the product of all values returned by
    ``self._get_map_lengths``. If the unmapped TI (map_index -1) still exists
    and is unfinished, it either becomes SKIPPED (length < 1, returning an
    empty list) or is converted in place to map_index 0. Otherwise only indexes
    above the current maximum map_index are created. Any TI whose map_index is
    >= the length is set to REMOVED.

    :return: The mapped task instances, in ascending order by map index.
    """
    from airflow.models.taskinstance import TaskInstance
    from airflow.settings import task_instance_mutation_hook

    # NOTE(review): functools.reduce without an initial value raises TypeError if
    # _get_map_lengths returns an empty mapping — confirm that cannot happen.
    total_length = functools.reduce(operator.mul, self._get_map_lengths(run_id, session=session).values())

    # State copied onto newly created TIs; only set when converting the unmapped TI.
    state: Optional[TaskInstanceState] = None
    unmapped_ti: Optional[TaskInstance] = (
        session.query(TaskInstance)
        .filter(
            TaskInstance.dag_id == self.dag_id,
            TaskInstance.task_id == self.task_id,
            TaskInstance.run_id == run_id,
            TaskInstance.map_index == -1,
            or_(TaskInstance.state.in_(State.unfinished), TaskInstance.state.is_(None)),
        )
        .one_or_none()
    )

    ret: List[TaskInstance] = []

    if unmapped_ti:
        # The unmapped task instance still exists and is unfinished, i.e. we
        # haven't tried to run it before.
        if total_length < 1:
            # If the upstream maps this to a zero-length value, simply mark the
            # unmapped task instance as SKIPPED (if needed).
            self.log.info(
                "Marking %s as SKIPPED since the map has %d values to expand",
                unmapped_ti,
                total_length,
            )
            unmapped_ti.state = TaskInstanceState.SKIPPED
            session.flush()
            # Early return: the REMOVED sweep at the end is not run on this path.
            return ret
        # Otherwise convert this into the first mapped index, and create
        # TaskInstance for other indexes.
        unmapped_ti.map_index = 0
        state = unmapped_ti.state
        self.log.debug("Updated in place to become %s", unmapped_ti)
        ret.append(unmapped_ti)
        indexes_to_map = range(1, total_length)
    else:
        # Only create "missing" ones.
        current_max_mapping = (
            session.query(func.max(TaskInstance.map_index))
            .filter(
                TaskInstance.dag_id == self.dag_id,
                TaskInstance.task_id == self.task_id,
                TaskInstance.run_id == run_id,
            )
            .scalar()
        )
        # NOTE(review): scalar() yields None when no TI rows match, which would make
        # `current_max_mapping + 1` raise TypeError — confirm a TI always exists here.
        indexes_to_map = range(current_max_mapping + 1, total_length)

    for index in indexes_to_map:
        # TODO: Make more efficient with bulk_insert_mappings/bulk_save_mappings.
        # TODO: Change `TaskInstance` ctor to take Operator, not BaseOperator
        ti = TaskInstance(self, run_id=run_id, map_index=index, state=state)  # type: ignore
        self.log.debug("Expanding TIs upserted %s", ti)
        task_instance_mutation_hook(ti)
        # merge() upserts and returns the session-attached instance.
        ti = session.merge(ti)
        ti.task = self
        ret.append(ti)

    # Set to "REMOVED" any (old) TaskInstances with map indices greater
    # than the current map value
    session.query(TaskInstance).filter(
        TaskInstance.dag_id == self.dag_id,
        TaskInstance.task_id == self.task_id,
        TaskInstance.run_id == run_id,
        TaskInstance.map_index >= total_length,
    ).update({TaskInstance.state: TaskInstanceState.REMOVED})

    session.flush()

    return ret
def manage_slas(self, dag: DAG, session: Session = None) -> None:
    """
    Finding all tasks that have SLAs defined, and sending alert emails
    where needed. New SLA misses are also recorded in the database.

    We are assuming that the scheduler runs often, so we only check for
    tasks that should have succeeded in the past hour.

    For each task's latest SUCCESS/SKIPPED instance, every later schedule whose
    following schedule plus the task's SLA is already in the past is recorded
    as an ``SlaMiss``. Un-notified misses then trigger ``dag.sla_miss_callback``
    and/or an email, and are flagged as notified if either succeeded.

    :raises TypeError: if a task's ``sla`` is set but is not a ``timedelta``.
    """
    self.log.info("Running SLA Checks for %s", dag.dag_id)
    if not any(isinstance(ti.sla, timedelta) for ti in dag.tasks):
        self.log.info("Skipping SLA check for %s because no tasks in DAG have SLAs", dag)
        return

    qry = (
        session.query(TI.task_id, func.max(TI.execution_date).label('max_ti'))
        .with_hint(TI, 'USE INDEX (PRIMARY)', dialect_name='mysql')
        .filter(TI.dag_id == dag.dag_id)
        .filter(or_(TI.state == State.SUCCESS, TI.state == State.SKIPPED))
        .filter(TI.task_id.in_(dag.task_ids))
        .group_by(TI.task_id)
        .subquery('sq')
    )

    max_tis: List[TI] = (
        session.query(TI)
        .filter(
            TI.dag_id == dag.dag_id,
            TI.task_id == qry.c.task_id,
            TI.execution_date == qry.c.max_ti,
        )
        .all()
    )

    ts = timezone.utcnow()
    for ti in max_tis:
        task = dag.get_task(ti.task_id)
        if task.sla is None:
            # max_tis covers every task of the DAG, not only those with an SLA;
            # without this skip `following_schedule + task.sla` raises TypeError.
            continue
        if not isinstance(task.sla, timedelta):
            raise TypeError(
                f"SLA is expected to be timedelta object, got "
                f"{type(task.sla)} in {task.dag_id}:{task.task_id}"
            )

        dttm = dag.following_schedule(ti.execution_date)
        while dttm < timezone.utcnow():
            following_schedule = dag.following_schedule(dttm)
            if following_schedule + task.sla < timezone.utcnow():
                session.merge(
                    SlaMiss(task_id=ti.task_id, dag_id=ti.dag_id, execution_date=dttm, timestamp=ts)
                )
            dttm = dag.following_schedule(dttm)
    session.commit()

    # pylint: disable=singleton-comparison
    slas: List[SlaMiss] = (
        session.query(SlaMiss)
        .filter(SlaMiss.notification_sent == False, SlaMiss.dag_id == dag.dag_id)  # noqa
        .all()
    )
    # pylint: enable=singleton-comparison

    if slas:  # pylint: disable=too-many-nested-blocks
        sla_dates: List[datetime.datetime] = [sla.execution_date for sla in slas]
        fetched_tis: List[TI] = (
            session.query(TI)
            .filter(TI.state != State.SUCCESS, TI.execution_date.in_(sla_dates), TI.dag_id == dag.dag_id)
            .all()
        )
        blocking_tis: List[TI] = []
        for ti in fetched_tis:
            if ti.task_id in dag.task_ids:
                ti.task = dag.get_task(ti.task_id)
                blocking_tis.append(ti)
            else:
                # TI for a task no longer in the DAG: drop it.
                session.delete(ti)
                session.commit()

        task_list = "\n".join(sla.task_id + ' on ' + sla.execution_date.isoformat() for sla in slas)
        blocking_task_list = "\n".join(
            ti.task_id + ' on ' + ti.execution_date.isoformat() for ti in blocking_tis
        )
        # Track whether email or any alert notification sent
        # We consider email or the alert callback as notifications
        email_sent = False
        notification_sent = False
        if dag.sla_miss_callback:
            # Execute the alert callback
            self.log.info('Calling SLA miss callback')
            try:
                dag.sla_miss_callback(dag, task_list, blocking_task_list, slas, blocking_tis)
                notification_sent = True
            except Exception:  # pylint: disable=broad-except
                self.log.exception("Could not call sla_miss_callback for DAG %s", dag.dag_id)
        # The closing tags were previously written as `<code></pre>`, leaving
        # both <code> elements unclosed.
        email_content = f"""\
        Here's a list of tasks that missed their SLAs:
        <pre><code>{task_list}\n</code></pre>
        Blocking tasks:
        <pre><code>{blocking_task_list}</code></pre>
        Airflow Webserver URL: {conf.get(section='webserver', key='base_url')}
        """

        tasks_missed_sla = []
        for sla in slas:
            try:
                task = dag.get_task(sla.task_id)
            except TaskNotFound:
                # task already deleted from DAG, skip it
                self.log.warning(
                    "Task %s doesn't exist in DAG anymore, skipping SLA miss notification.", sla.task_id
                )
                continue
            tasks_missed_sla.append(task)

        emails: Set[str] = set()
        for task in tasks_missed_sla:
            if task.email:
                if isinstance(task.email, str):
                    emails |= set(get_email_address_list(task.email))
                elif isinstance(task.email, (list, tuple)):
                    emails |= set(task.email)
        if emails:
            try:
                send_email(emails, f"[airflow] SLA miss on DAG={dag.dag_id}", email_content)
                email_sent = True
                notification_sent = True
            except Exception:  # pylint: disable=broad-except
                Stats.incr('sla_email_notification_failure')
                self.log.exception("Could not send SLA Miss email notification for DAG %s", dag.dag_id)
        # If we sent any notification, update the sla_miss table
        if notification_sent:
            for sla in slas:
                sla.email_sent = email_sent
                sla.notification_sent = True
                session.merge(sla)
        session.commit()
def add(cls, session: Session, mpid: int) -> None:
    """Record an opt-out for the patient identified by ``mpid``.

    A new ``cls(mpid=mpid)`` instance is merged into ``session``; nothing is
    committed here.
    """
    log.debug("Adding opt-out for MPID {}".format(mpid))
    session.merge(cls(mpid=mpid))
def update_state(
    self, session: Session = None, execute_callbacks: bool = True
) -> Tuple[List[TI], Optional[callback_requests.DagCallbackRequest]]:
    """
    Determines the overall state of the DagRun based on the state
    of its TaskInstances.

    Outcomes, checked in order:
    - no unfinished tasks and a leaf failed -> FAILED;
    - no unfinished tasks and all leaves succeeded -> SUCCESS;
    - unfinished tasks remain, none depend on past, have a per-DAG TI limit
      or are deferred, and nothing is runnable -> FAILED (deadlock);
    - otherwise -> RUNNING.
    On FAILED/SUCCESS the DAG callback is either run directly or returned as
    a ``DagCallbackRequest``, depending on ``execute_callbacks``.

    :param session: Sqlalchemy ORM Session
    :type session: Session
    :param execute_callbacks: Should dag callbacks (success/failure, SLA etc) be invoked
        directly (default: true) or recorded as a pending request in the ``callback`` property
    :type execute_callbacks: bool
    :return: Tuple containing tis that can be scheduled in the current loop & `callback` that
        needs to be executed
    """
    # Callback to execute in case of Task Failures
    callback: Optional[callback_requests.DagCallbackRequest] = None

    start_dttm = timezone.utcnow()
    self.last_scheduling_decision = start_dttm
    with Stats.timer(f"dagrun.dependency-check.{self.dag_id}"):
        dag = self.get_dag()
        info = self.task_instance_scheduling_decisions(session)

        tis = info.tis
        schedulable_tis = info.schedulable_tis
        changed_tis = info.changed_tis
        finished_tasks = info.finished_tasks
        unfinished_tasks = info.unfinished_tasks

        none_depends_on_past = all(not t.task.depends_on_past for t in unfinished_tasks)
        none_task_concurrency = all(t.task.max_active_tis_per_dag is None for t in unfinished_tasks)
        none_deferred = all(t.state != State.DEFERRED for t in unfinished_tasks)

        if unfinished_tasks and none_depends_on_past and none_task_concurrency and none_deferred:
            # small speed up
            # `are_runnable_tasks` is only bound here; the deadlock branch below
            # re-checks this same condition before reading it.
            are_runnable_tasks = (
                schedulable_tis
                or self._are_premature_tis(unfinished_tasks, finished_tasks, session)
                or changed_tis
            )

    leaf_task_ids = {t.task_id for t in dag.leaves}
    leaf_tis = [ti for ti in tis if ti.task_id in leaf_task_ids]

    # if all roots finished and at least one failed, the run failed
    if not unfinished_tasks and any(leaf_ti.state in State.failed_states for leaf_ti in leaf_tis):
        self.log.error('Marking run %s failed', self)
        self.set_state(State.FAILED)
        if execute_callbacks:
            dag.handle_callback(self, success=False, reason='task_failure', session=session)
        elif dag.has_on_failure_callback:
            callback = callback_requests.DagCallbackRequest(
                full_filepath=dag.fileloc,
                dag_id=self.dag_id,
                execution_date=self.execution_date,
                is_failure_callback=True,
                msg='task_failure',
            )

    # if all leaves succeeded and no unfinished tasks, the run succeeded
    elif not unfinished_tasks and all(leaf_ti.state in State.success_states for leaf_ti in leaf_tis):
        self.log.info('Marking run %s successful', self)
        self.set_state(State.SUCCESS)
        if execute_callbacks:
            dag.handle_callback(self, success=True, reason='success', session=session)
        elif dag.has_on_success_callback:
            callback = callback_requests.DagCallbackRequest(
                full_filepath=dag.fileloc,
                dag_id=self.dag_id,
                execution_date=self.execution_date,
                is_failure_callback=False,
                msg='success',
            )

    # if *all tasks* are deadlocked, the run failed
    elif (
        unfinished_tasks
        and none_depends_on_past
        and none_task_concurrency
        and none_deferred
        and not are_runnable_tasks
    ):
        self.log.error('Deadlock; marking run %s failed', self)
        self.set_state(State.FAILED)
        if execute_callbacks:
            dag.handle_callback(self, success=False, reason='all_tasks_deadlocked', session=session)
        elif dag.has_on_failure_callback:
            callback = callback_requests.DagCallbackRequest(
                full_filepath=dag.fileloc,
                dag_id=self.dag_id,
                execution_date=self.execution_date,
                is_failure_callback=True,
                msg='all_tasks_deadlocked',
            )

    # finally, if the roots aren't done, the dag is still running
    else:
        self.set_state(State.RUNNING)

    self._emit_true_scheduling_delay_stats_for_finished_state(finished_tasks)
    self._emit_duration_stats_for_finished_state()

    session.merge(self)

    return schedulable_tis, callback
def expand_mapped_task(self, run_id: str, *, session: Session) -> Tuple[Sequence["TaskInstance"], int]:
    """Create the mapped task instances for mapped task.

    If upstream values are not fully populated, the length is unknown: an
    unfinished unmapped TI (map_index -1) is set to UPSTREAM_FAILED and nothing
    is expanded. A zero length marks that TI SKIPPED. Otherwise the unmapped TI
    is converted in place to map_index 0, or, when it no longer exists, only
    indexes above the current maximum are created. TIs whose map_index is >= the
    length (0 when unknown) are set to REMOVED.

    :return: The newly created mapped TaskInstances (if any) in ascending order by map index, and the maximum
        map_index (``total_length - 1``, i.e. -1 when nothing was expanded).
    """
    from airflow.models.taskinstance import TaskInstance
    from airflow.settings import task_instance_mutation_hook

    total_length: Optional[int]
    try:
        total_length = self._get_specified_expand_input().get_total_map_length(run_id, session=session)
    except NotFullyPopulated as e:
        self.log.info(
            "Cannot expand %r for run %s; missing upstream values: %s",
            self,
            run_id,
            sorted(e.missing),
        )
        total_length = None

    # State copied onto newly created TIs.
    state: Optional[TaskInstanceState] = None
    unmapped_ti: Optional[TaskInstance] = (
        session.query(TaskInstance)
        .filter(
            TaskInstance.dag_id == self.dag_id,
            TaskInstance.task_id == self.task_id,
            TaskInstance.run_id == run_id,
            TaskInstance.map_index == -1,
            or_(TaskInstance.state.in_(State.unfinished), TaskInstance.state.is_(None)),
        )
        .one_or_none()
    )

    all_expanded_tis: List[TaskInstance] = []

    if unmapped_ti:
        # The unmapped task instance still exists and is unfinished, i.e. we
        # haven't tried to run it before.
        if total_length is None:
            # If the map length cannot be calculated (due to unavailable
            # upstream sources), fail the unmapped task.
            unmapped_ti.state = TaskInstanceState.UPSTREAM_FAILED
            indexes_to_map: Iterable[int] = ()
        elif total_length < 1:
            # If the upstream maps this to a zero-length value, simply mark
            # the unmapped task instance as SKIPPED (if needed).
            self.log.info(
                "Marking %s as SKIPPED since the map has %d values to expand",
                unmapped_ti,
                total_length,
            )
            unmapped_ti.state = TaskInstanceState.SKIPPED
            indexes_to_map = ()
        else:
            # Otherwise convert this into the first mapped index, and create
            # TaskInstance for other indexes.
            unmapped_ti.map_index = 0
            self.log.debug("Updated in place to become %s", unmapped_ti)
            all_expanded_tis.append(unmapped_ti)
            indexes_to_map = range(1, total_length)
        # Only matters in the conversion branch; the others map no indexes.
        state = unmapped_ti.state
    elif not total_length:
        # Nothing to fixup.
        indexes_to_map = ()
    else:
        # Only create "missing" ones.
        current_max_mapping = (
            session.query(func.max(TaskInstance.map_index))
            .filter(
                TaskInstance.dag_id == self.dag_id,
                TaskInstance.task_id == self.task_id,
                TaskInstance.run_id == run_id,
            )
            .scalar()
        )
        # NOTE(review): scalar() yields None when no TI rows match, which would make
        # `current_max_mapping + 1` raise TypeError — confirm a TI always exists here.
        indexes_to_map = range(current_max_mapping + 1, total_length)

    for index in indexes_to_map:
        # TODO: Make more efficient with bulk_insert_mappings/bulk_save_mappings.
        ti = TaskInstance(self, run_id=run_id, map_index=index, state=state)
        self.log.debug("Expanding TIs upserted %s", ti)
        task_instance_mutation_hook(ti)
        ti = session.merge(ti)
        ti.refresh_from_task(self)  # session.merge() loses task information.
        all_expanded_tis.append(ti)

    # Coerce the None case to 0 -- these two are almost treated identically,
    # except the unmapped ti (if exists) is marked to different states.
    total_expanded_ti_count = total_length or 0

    # Set to "REMOVED" any (old) TaskInstances with map indices greater
    # than the current map value
    session.query(TaskInstance).filter(
        TaskInstance.dag_id == self.dag_id,
        TaskInstance.task_id == self.task_id,
        TaskInstance.run_id == run_id,
        TaskInstance.map_index >= total_expanded_ti_count,
    ).update({TaskInstance.state: TaskInstanceState.REMOVED})

    session.flush()
    return all_expanded_tis, total_expanded_ti_count - 1
def upgrade():
    """Add 3-digit IRONMAN organization and study identifiers.

    For each IRONMAN organization identifier matched by ``org_pattern``, a
    ``'146-0' + <group 1>`` identifier is linked to the same organization.
    Then every external study identifier (``170-...``) whose organization
    segment was one of those replaced gets a ``170-<new org>-<patient>``
    counterpart linked to the same user(s). Existing identifiers and links are
    looked up and reused rather than duplicated.
    """
    session = Session(bind=op.get_bind(), expire_on_commit=False)

    # All IRONMAN orgs need a 3 digit version
    IRONMAN_system = 'http://pcctc.org/'
    ironman_org_ids = [
        (ident.id, ident._value)
        for ident in Identifier.query.filter(
            Identifier.system == IRONMAN_system).with_entities(
            Identifier.id, Identifier._value)]
    # Set for O(1) membership tests inside the loop.
    existing_values = {value for _, value in ironman_org_ids}
    replacements = {}
    for io_id, io_value in ironman_org_ids:
        found = org_pattern.match(io_value)
        if not found:
            continue

        # avoid probs if run again - don't add if already present
        needed = '146-0{}'.format(found.group(1))
        replacements[found.group(1)] = '0{}'.format(found.group(1))
        if needed not in existing_values:
            needed_i = Identifier(
                use='secondary', system=IRONMAN_system, _value=needed)
        else:
            needed_i = Identifier.query.filter(
                Identifier.system == IRONMAN_system).filter(
                Identifier._value == needed).one()

        # add a 3 digit identifier and link with same org
        oi = OrganizationIdentifier.query.filter(
            OrganizationIdentifier.identifier_id == io_id).one()
        needed_oi = OrganizationIdentifier.query.filter(
            OrganizationIdentifier.organization_id ==
            oi.organization_id).filter(
            OrganizationIdentifier.identifier == needed_i).first()
        if not needed_oi:
            # needed_i may belong to the ``Identifier.query`` session; bring it
            # into the local session before linking.
            needed_i = session.merge(needed_i)
            needed_oi = OrganizationIdentifier(
                organization_id=oi.organization_id, identifier=needed_i)
            session.add(needed_oi)

    # All IRONMAN users with a 2 digit ID referencing one of the replaced
    # values needs a 3 digit version
    ironman_study_ids = Identifier.query.filter(
        Identifier.system == TRUENTH_EXTERNAL_STUDY_SYSTEM).filter(
        Identifier._value.like('170-%')).with_entities(
        Identifier.id, Identifier._value)

    for iid, ival in ironman_study_ids:
        found = study_pattern.match(ival)
        if not found:
            continue
        org_segment = found.group(1)
        patient_segment = found.group(2)

        # only add if also one of the new org ids
        if org_segment not in replacements:
            continue

        needed = '170-{}-{}'.format(
            replacements[org_segment], patient_segment)

        # add a 3 digit identifier and link with same user(s),
        # if not already present
        uis = UserIdentifier.query.filter(
            UserIdentifier.identifier_id == iid)
        needed_i = Identifier.query.filter(
            Identifier.system == TRUENTH_EXTERNAL_STUDY_SYSTEM).filter(
            Identifier._value == needed).first()
        if not needed_i:
            needed_i = Identifier(
                use='secondary', system=TRUENTH_EXTERNAL_STUDY_SYSTEM,
                _value=needed)
        for ui in uis:
            needed_ui = UserIdentifier.query.filter(
                UserIdentifier.user_id == ui.user_id).filter(
                UserIdentifier.identifier == needed_i).first()
            if not needed_ui:
                # As in the org loop above: an identifier loaded through
                # ``Identifier.query`` is attached to a different session, and
                # cascading it into ``session`` via the new link would raise.
                # Merge only once -- merging an already-pending instance again
                # would create a duplicate Identifier.
                if needed_i not in session:
                    needed_i = session.merge(needed_i)
                needed_ui = UserIdentifier(
                    user_id=ui.user_id, identifier=needed_i)
                session.add(needed_ui)

    session.commit()
def merge(self, entity, load=True):
    """Merge ``entity`` into this session inside its own transaction.

    Begins a transaction, merges via ``SaSession.merge`` and commits. If the
    merge or the commit raises, the transaction is rolled back before the
    exception propagates, so the session is not left with an open transaction.

    :param entity: instance to merge.
    :param load: passed straight through to ``SaSession.merge``.
    :return: the merged instance returned by ``SaSession.merge``.
    """
    self.begin()
    try:
        merged = SaSession.merge(self, entity, load=load)
        self.commit()
    except Exception:
        self.rollback()
        raise
    return merged
class CrawlProcessor(object):
    """Analyse crawled HTML records and persist the results.

    Each ``(crawl_id, record)`` item is checked against existing Articles and,
    if new, parsed for dates, language, headline, keywords, sentiment (via
    ``DocumentClassifier``) and links. The resulting ORM objects are added to a
    single SQLAlchemy session and committed once per record.
    """

    __VERSION__ = "CrawlProcessor-0.2.1"

    def __init__(self, engine, redis_server, stop_list="keyword_filter.txt"):
        """Set up the DB session, controllers, workers and the keyword stop list.

        :param engine: a SQLAlchemy connection string, or an existing engine.
        :param redis_server: Redis host (port 6379; db 1 is used for keyword
            resolution, db 2 for domain resolution).
        :param stop_list: path to a stop-list file, or an iterable of lines
            (e.g. an open file); each stripped line becomes one stop word.
        """
        # StringTypes covers both str and unicode, so a unicode connection
        # string is not mistaken for an engine object.
        if isinstance(engine, types.StringTypes):
            logging.info("Using connection string '%s'" % (engine,))
            new_engine = create_engine(engine, encoding='utf-8', isolation_level="READ COMMITTED")
            if "sqlite:" in engine:
                logging.debug("Setting text factory for unicode compat.")
                new_engine.raw_connection().connection.text_factory = str
            self._engine = new_engine
        else:
            logging.info("Using existing engine...")
            self._engine = engine

        logging.info("Binding session...")
        self._session = Session(bind=self._engine, autocommit=False)

        if isinstance(stop_list, types.StringTypes):
            # We opened this file ourselves, so close it again.
            with open(stop_list) as stop_list_fp:
                self.stop_list = set(line.strip() for line in stop_list_fp)
        else:
            # Caller-supplied iterable of lines; the caller owns its lifetime.
            self.stop_list = set(line.strip() for line in stop_list)

        self.cls = DocumentClassifier()
        self.dc = DomainController(self._engine, self._session)
        self.ac = ArticleController(self._engine, self._session)
        self.ex = extract.TermExtractor()
        self.kwc = KeywordController(self._engine, self._session)
        self.swc = SoftwareVersionsController(self._engine, self._session)
        self.redis_kw = redis.Redis(host=redis_server, port=6379, db=1)
        self.redis_dm = redis.Redis(host=redis_server, port=6379, db=2)

        # Domain resolution uses its own session, separate from self._session.
        dm_session = Session(bind=self._engine, autocommit=False)
        self.drw = DomainResolutionWorker(dm_session, self.redis_dm)

    def _resolve_domain_id(self, url):
        """Return the domain identifier for ``url``.

        Polls ``self.drw.get_domain`` until it returns something other than None.
        """
        domain_key = self.dc.get_Domain_key(url)
        domain_id = None
        while domain_id is None:
            domain_id = self.drw.get_domain(domain_key)
        return domain_id

    def _check_processed(self, item):
        """Return True if no Article exists yet for this crawl id, domain and path.

        Returns False (and logs an error) when one or more matching Articles
        already exist.
        """
        crawl_id, record = item
        headers, content, url, date_crawled, content_type = record

        path = self.ac.get_path_fromurl(url)
        logging.info("_check_processed: retrieving domain...")
        domain_identifier = self._resolve_domain_id(url)

        query = (self._session.query(Article)
                 .filter_by(crawl_id=crawl_id)
                 .filter_by(domain_id=domain_identifier)
                 .filter_by(path=path))
        try:
            query.one()
            logging.error("%s: already processed", url)
            return False
        except sqlalchemy.orm.exc.MultipleResultsFound:
            logging.error("%s: appears to have been already processed multiple times", url)
            return False
        except sqlalchemy.orm.exc.NoResultFound:
            logging.info("%s: hasn't been processed yet", url)
            return True

    def process_record(self, item):
        """Process one ``(crawl_id, record)`` pair unless it was already processed.

        :return: the new Article id on success; None if the item was already
            processed, was rejected by ``_process_record``, or still failed to
            commit after the retries.
        :raises ValueError: if ``item`` does not have exactly two elements.
        """
        if len(item) != 2:
            raise ValueError(item)
        if not self._check_processed(item):
            return None
        ret, retries = None, 2
        # _process_record returns None when its commit hit an OperationalError;
        # only that case is retried (at most two attempts in total).
        while ret is None and retries > 0:
            try:
                retries -= 1
                ret = self._process_record(item)
            except Exception as ex:
                import traceback
                sys.stderr.write("%s\n" % (ex,))
                traceback.print_exc()
                # Bare raise keeps the original traceback (``raise ex`` loses it on Python 2).
                raise
        # ``is False`` so that a legitimate article id of 0 is not treated as a rejection.
        if ret is False:
            return None
        return ret

    def _process_record(self, item_arg):
        """Analyse one record and stage all derived rows in ``self._session``.

        :return: the new Article's id after a successful commit; False when the
            record is rejected (nasa.gov URL, non-HTML, no body/content,
            non-English, classification error) -- in the later cases the
            Article stays in the session with its status set, uncommitted here;
            None when the commit raised OperationalError (session rolled back).
        """
        crawl_id, record = item_arg
        headers, content, url, date_crawled, content_type = record

        assert headers is not None
        assert content is not None
        assert url is not None
        assert date_crawled is not None
        assert content_type is not None

        status = "Processed"

        # Fix for a seg-fault
        if "nasa.gov" in url:
            return False

        # Sort out the domain
        logging.info("Retrieving domain...")
        domain_identifier = self._resolve_domain_id(url)
        domain = self._session.query(Domain).get(domain_identifier)
        assert domain is not None

        # Build database objects
        path = self.ac.get_path_fromurl(url)
        article = Article(path, date_crawled, crawl_id, domain, status)
        self._session.add(article)

        classified_by = self.swc.get_SoftwareVersion_fromstr(pysen.__VERSION__)
        assert classified_by is not None

        if content_type != 'text/html':
            logging.error("Unsupported content type: %s", str(content_type))
            article.status = "UnsupportedType"
            return False

        # Start the async transaction to get the plain text
        worker_req_thread = BoilerPipeWorker(content)
        worker_req_thread.start()

        # Whilst that's executing, parse the document
        logging.info("Parsing HTML...")
        html = BeautifulSoup(content)

        if html is None or html.body is None:
            article.status = "NoContent"
            return False

        # Extract the dates
        date_dict = pydate.get_dates(html)

        if len(date_dict) == 0:
            status = "NoDates"

        # Detect the language
        lang, lang_certainty = langid.classify(content)

        # Wait for the BoilerPipe thread to complete
        worker_req_thread.join()
        logging.debug(worker_req_thread.result)
        logging.debug(worker_req_thread.version)

        if worker_req_thread.result is None:
            article.status = "NoContent"
            return False

        # If the language isn't English, skip it
        if lang != "en":
            logging.info("language: %s with certainty %.2f - skipping...", lang, lang_certainty)
            article.status = "LanguageError"  # Replace with something appropriate
            return False

        content = worker_req_thread.result.encode('ascii', 'ignore')

        # Headline extraction: the first h6..h1 node (lowest level first)
        # whose text appears in the extracted body text.
        headline = None
        for h_level in range(6, 0, -1):
            for node in html.findAll("h%d" % (h_level,)):
                if node.text in content:
                    headline = node.text
                    break
            if headline is not None:
                break

        # Run keyword extraction
        keywords = self.ex(content)
        kset = KeywordSet(self.stop_list)
        nnp_sets_scored = set((word, freq) for word, freq, _amnt in keywords)

        # Collect proper nouns (NNP), their counts, and adjacent NNP pairs.
        nnp_adj = set([])
        nnp_set = set([])
        nnp_vector = []
        for sentence in sent_tokenize(content):
            text = nltk.word_tokenize(sentence)
            pos = nltk.pos_tag(text)
            for tag, group in itertools.groupby(pos, lambda x: x[1]):
                if tag != 'NNP':
                    continue
                nnp_buf = [word for word, _speech in group]
                nnp_set.update(nnp_buf)
                nnp_vector.extend(nnp_buf)
                for i, j in zip(nnp_buf[0:-1], nnp_buf[1:]):
                    nnp_adj.add((i, j))

        nnp_vector = filter(lambda x: x.lower() not in self.stop_list, nnp_vector)
        nnp_counter = Counter(nnp_vector)
        for word in nnp_set:
            # Score each proper noun by its own frequency.
            nnp_sets_scored.add((word, nnp_counter[word]))

        for item, score in sorted(nnp_sets_scored, key=lambda x: x[1], reverse=True):
            try:
                if isinstance(item, (list, tuple)):
                    kset.add(' '.join(item))
                else:
                    kset.add(item)
            except ValueError:
                break

        scored_nnp_adj = []
        for item1, item2 in nnp_adj:
            score = nnp_counter[item1] + nnp_counter[item2]
            scored_nnp_adj.append((item1, item2, score))

        # Keep the KEYWORD_LIMIT highest-scoring pairs; the score is x[2].
        nnp_adj = []
        for item1, item2, score in sorted(scored_nnp_adj, key=lambda x: x[2], reverse=True):
            if len(nnp_adj) < KEYWORD_LIMIT:
                nnp_adj.append((item1, item2))
            else:
                break

        # Generate list of all keywords
        keywords = set([])
        for keyword in kset:
            try:
                keywords.add(Keyword(keyword))
            except ValueError as ex:
                logging.error(ex)
        for item1, item2 in nnp_adj:
            for adj_word in (item1, item2):
                try:
                    keywords.add(Keyword(adj_word))
                except ValueError as ex:
                    logging.error(ex)

        # Resolve keyword identifiers
        keyword_resolution_worker = KeywordResolutionWorker(set([k.word for k in keywords]), self.redis_kw)
        keyword_resolution_worker.start()

        # Run sentiment analysis
        trace = []
        features = self.cls.classify(worker_req_thread.result, trace)
        (label, length, classified, pos_sentences, neg_sentences,
         pos_phrases, neg_phrases) = features[0:7]

        # Convert Pysen's model into database models
        try:
            doc = Document(article.id, label, length, pos_sentences, neg_sentences, pos_phrases, neg_phrases, headline)
        except ValueError as ex:
            logging.error(ex)
            logging.error("Skipping this document...")
            article.status = "ClassificationError"
            return False

        self._session.add(doc)
        extracted_phrases = set([])
        for sentence, score, phrase_trace in trace:
            # Classify the sentence by the tag of the first text node containing it.
            sentence_type = "Unknown"
            for node in html.findAll(text=True):
                if sentence.text in node.strip():
                    sentence_type = node.parent.name.upper()
                    break

            if sentence_type not in ["H1", "H2", "H3", "H4", "H5", "H6", "P", "Unknown"]:
                sentence_type = "Other"

            label, average, prob, pos, neg, probs, _scores = score
            s = Sentence(doc, label, average, prob, sentence_type)
            self._session.add(s)
            for phrase, prob, score, label in phrase_trace:
                p = Phrase(s, score, prob, label)
                self._session.add(p)
                extracted_phrases.add((phrase, p))

        # Wait for keyword resolution to finish
        keyword_resolution_worker.join()
        keyword_mapping = keyword_resolution_worker.out_keywords

        # Associate extracted keywords with phrases
        keyword_objects, short_keywords = kset.convert(keyword_mapping, self.kwc)
        for k in keyword_objects:
            self._session.merge(k)
        for p, p_obj in extracted_phrases:
            for k in keyword_objects:
                if k.word in p.get_text():
                    # Never added explicitly; presumably persisted through a
                    # relationship cascade from p_obj -- TODO confirm.
                    KeywordIncidence(k, p_obj)

        # Save the keyword adjacency list
        for i, j in kset.convert_adj_tuples(nnp_adj, keyword_mapping, self.kwc):
            self._session.merge(i)
            self._session.merge(j)
            kwa = KeywordAdjacency(i, j, doc)
            self._session.add(kwa)

        # Build date objects
        for key in date_dict:
            rec = date_dict[key]
            if "dates" not in rec:
                logging.error("OK: 'dates' is not in a pydate result record.")
                continue
            dlen = len(rec["dates"])
            if rec["text"] not in content:
                logging.debug("'%s' is not in %s", rec["text"], content)
                continue
            if dlen > 1:
                for date, day_first, year_first in rec["dates"]:
                    try:
                        dobj = AmbiguousDate(date, doc, day_first, year_first, rec["prep"], key)
                    except ValueError as ex:
                        logging.error(ex)
                        continue
                    self._session.add(dobj)
            elif dlen == 1:
                for date, day_first, year_first in rec["dates"]:
                    dobj = CertainDate(date, doc, key)
                    self._session.add(dobj)
            else:
                logging.error("'dates' in a pydate result set contains no records.")

        # Process links
        for link in html.findAll('a'):
            if not link.has_attr("href"):
                logging.debug("skipping %s: no href", link)
                continue

            # Only keep links whose anchor text survived boilerplate removal.
            if any(node not in worker_req_thread.result for node in link.findAll(text=True)):
                logging.debug("skipping %s because it's not in the body text", link)
                # ``continue`` (not ``break``): one off-body link must not drop
                # every remaining link on the page.
                continue

            href, _, _ = link["href"].partition("#")
            if "http://" in href:
                try:
                    domain_id = self._resolve_domain_id(href)
                    href_domain = self._session.query(Domain).get(domain_id)
                except ValueError as ex:
                    logging.error(ex)
                    logging.error("Skipping this link")
                    continue
                href_path = self.ac.get_path_fromurl(href)
                lnk = AbsoluteLink(doc, href_domain, href_path)
            else:
                href_path = href
                try:
                    lnk = RelativeLink(doc, href_path)
                except ValueError as ex:
                    logging.error(ex)
                    logging.error("Skipping link")
                    continue
            self._session.add(lnk)
            logging.debug("Adding: %s", lnk)

        # Construct software involvment records
        self_sir = SoftwareInvolvementRecord(self.swc.get_SoftwareVersion_fromstr(self.__VERSION__), "Processed", doc)
        date_sir = SoftwareInvolvementRecord(self.swc.get_SoftwareVersion_fromstr(pydate.__VERSION__), "Dated", doc)
        clas_sir = SoftwareInvolvementRecord(self.swc.get_SoftwareVersion_fromstr(pysen.__VERSION__), "Classified", doc)
        extr_sir = SoftwareInvolvementRecord(self.swc.get_SoftwareVersion_fromstr(worker_req_thread.version), "Extracted", doc)
        for sw in [self_sir, date_sir, clas_sir, extr_sir]:
            self._session.merge(sw, load=True)

        logging.debug("Domain: %s", domain)
        logging.debug("Path: %s", path)
        article.status = status

        # Commit to database; on OperationalError roll back and return None so
        # that process_record can retry.
        try:
            self._session.commit()
        except OperationalError as ex:
            logging.error(ex)
            self._session.rollback()
            return None

        return article.id

    def finalize(self):
        """Commit whatever is still pending in the session."""
        self._session.commit()
def verify_integrity(self, session: Session = NEW_SESSION):
    """
    Bring this DagRun's task instances in line with the tasks of its DAG.

    Existing task instances whose task can no longer be found are flagged
    REMOVED (only while the run is not RUNNING and the DAG is not partial);
    REMOVED task instances whose task has reappeared are reset to State.NONE.
    Every existing task instance goes through the mutation hook and is merged
    into the session. Task instances are then created for DAG tasks that have
    none yet (skipping tasks that start after the execution date, unless this
    is a backfill). An IntegrityError while creating them is logged and the
    session is rolled back.

    :param session: Sqlalchemy ORM Session
    :type session: Session
    """
    from airflow.settings import task_instance_mutation_hook

    dag = self.get_dag()
    existing_tis = self.get_task_instances(session=session)

    # Pass 1: reconcile the task instances that already exist.
    task_ids = set()
    for ti in existing_tis:
        task_instance_mutation_hook(ti)
        task_ids.add(ti.task_id)
        try:
            task = dag.get_task(ti.task_id)
        except AirflowException:
            task = None
            # An already-REMOVED ti is left alone; otherwise only flag it when
            # the run is not RUNNING and the DAG is not partial.
            if ti.state != State.REMOVED and self.state != State.RUNNING and not dag.partial:
                self.log.warning(
                    "Failed to get task '%s' for dag '%s'. Marking it as removed.", ti, dag)
                Stats.incr(f"task_removed_from_dag.{dag.dag_id}", 1, 1)
                ti.state = State.REMOVED

        if task is not None and ti.state == State.REMOVED:
            self.log.info(
                "Restoring task '%s' which was previously removed from DAG '%s'", ti, dag)
            Stats.incr(f"task_restored_to_dag.{dag.dag_id}", 1, 1)
            ti.state = State.NONE
        session.merge(ti)

    # Pass 2: create task instances for tasks that have none.
    def needs_new_ti(task: "BaseOperator"):
        if task.task_id in task_ids:
            return False
        return self.is_backfill or task.start_date <= self.execution_date

    created_counts: Dict[str, int] = defaultdict(int)
    # The default (empty) hook in airflow.settings carries ``is_noop``; if the
    # attribute is missing, the hook was replaced and must run for every new TI.
    hook_is_noop = getattr(task_instance_mutation_hook, 'is_noop', False)

    def build_mapping(task: "BaseOperator"):
        created_counts[task.task_type] += 1
        return TI.insert_mapping(self.run_id, task)

    def build_ti(task: "BaseOperator") -> TI:
        new_ti = TI(task, run_id=self.run_id)
        task_instance_mutation_hook(new_ti)
        created_counts[new_ti.operator] += 1
        return new_ti

    missing_tasks = [task for task in dag.task_dict.values() if needs_new_ti(task)]
    try:
        if hook_is_noop:
            session.bulk_insert_mappings(TI, map(build_mapping, missing_tasks))
        else:
            session.bulk_save_objects(map(build_ti, missing_tasks))

        for task_type, count in created_counts.items():
            Stats.incr(f"task_instance_created-{task_type}", count)
        session.flush()
    except IntegrityError as err:
        self.log.info(str(err))
        self.log.info(
            'Hit IntegrityError while creating the TIs for %s- %s', dag.dag_id, self.run_id)
        self.log.info('Doing session rollback.')
        # TODO[HA]: We probably need to savepoint this so we can keep the transaction alive.
        session.rollback()
def update_state(
    self, session: Session = None, execute_callbacks: bool = True
) -> Tuple[List[TI], Optional[callback_requests.DagCallbackRequest]]:
    """
    Work out the DagRun's overall state from the states of its TaskInstances.

    The run becomes FAILED when nothing is unfinished and some leaf failed (or
    was upstream-failed), SUCCESS when nothing is unfinished and every leaf
    succeeded or was skipped, FAILED again when every unfinished task is
    deadlocked, and RUNNING otherwise.

    :param session: Sqlalchemy ORM Session
    :type session: Session
    :param execute_callbacks: Should dag callbacks (success/failure, SLA etc) be invoked
        directly (default: true) or recorded as a pending request in the ``callback`` property
    :type execute_callbacks: bool
    :return: Tuple containing tis that can be scheduled in the current loop & `callback` that
        needs to be executed
    """
    start_dttm = timezone.utcnow()
    self.last_scheduling_decision = start_dttm

    dag = self.get_dag()
    ready_tis: List[TI] = []
    tis = list(self.get_task_instances(session=session, state=State.task_states + (State.SHUTDOWN,)))
    self.log.debug("number of tis tasks for %s: %s task(s)", self, len(tis))
    for ti in tis:
        ti.task = dag.get_task(ti.task_id)

    unfinished_states = State.unfinished()
    finished_states = State.finished() + [State.UPSTREAM_FAILED]
    unfinished_tasks = [t for t in tis if t.state in unfinished_states]
    finished_tasks = [t for t in tis if t.state in finished_states]
    none_depends_on_past = not any(t.task.depends_on_past for t in unfinished_tasks)
    none_task_concurrency = all(t.task.task_concurrency is None for t in unfinished_tasks)

    if unfinished_tasks:
        scheduleable_tasks = [ut for ut in unfinished_tasks if ut.state in SCHEDULEABLE_STATES]
        self.log.debug("number of scheduleable tasks for %s: %s task(s)", self, len(scheduleable_tasks))
        ready_tis, changed_tis = self._get_ready_tis(scheduleable_tasks, finished_tasks, session)
        self.log.debug("ready tis length for %s: %s task(s)", self, len(ready_tis))
        if none_depends_on_past and none_task_concurrency:
            # small speed up
            are_runnable_tasks = (
                ready_tis
                or self._are_premature_tis(unfinished_tasks, finished_tasks, session)
                or changed_tis
            )

    duration = timezone.utcnow() - start_dttm
    Stats.timing(f"dagrun.dependency-check.{self.dag_id}", duration)

    leaf_task_ids = {t.task_id for t in dag.leaves}
    leaf_states = [ti.state for ti in tis if ti.task_id in leaf_task_ids]

    def conclude(final_state, success: bool, reason: str):
        """Set the final state, then either run the DAG callback or return it as a request."""
        self.set_state(final_state)
        if execute_callbacks:
            dag.handle_callback(self, success=success, reason=reason, session=session)
            return None
        return callback_requests.DagCallbackRequest(
            full_filepath=dag.fileloc,
            dag_id=self.dag_id,
            execution_date=self.execution_date,
            is_failure_callback=not success,
            msg=reason,
        )

    callback: Optional[callback_requests.DagCallbackRequest] = None
    if not unfinished_tasks and any(s in {State.FAILED, State.UPSTREAM_FAILED} for s in leaf_states):
        # nothing left to run and at least one leaf failed
        self.log.error('Marking run %s failed', self)
        callback = conclude(State.FAILED, False, 'task_failure')
    elif not unfinished_tasks and all(s in {State.SUCCESS, State.SKIPPED} for s in leaf_states):
        # nothing left to run and every leaf succeeded or was skipped
        self.log.info('Marking run %s successful', self)
        callback = conclude(State.SUCCESS, True, 'success')
    elif unfinished_tasks and none_depends_on_past and none_task_concurrency and not are_runnable_tasks:
        # every unfinished task is deadlocked
        self.log.error('Deadlock; marking run %s failed', self)
        callback = conclude(State.FAILED, False, 'all_tasks_deadlocked')
    else:
        # the leaves are not done yet, so the dag is still running
        self.set_state(State.RUNNING)

    self._emit_duration_stats_for_finished_state()
    session.merge(self)
    return ready_tis, callback
def update_state(self, session: Session = None) -> List[TI]:
    """
    Work out the DagRun's overall state from the states of its TaskInstances.

    The DAG callback is invoked directly when the run fails or succeeds. The
    run is then merged into the session and the session is committed.

    :param session: Sqlalchemy ORM Session
    :type session: Session
    :return: ready_tis: the tis that can be scheduled in the current loop
    :rtype ready_tis: list[airflow.models.TaskInstance]
    """
    dag = self.get_dag()
    ready_tis: List[TI] = []
    tis = list(self.get_task_instances(session=session, state=State.task_states + (State.SHUTDOWN,)))
    self.log.debug("number of tis tasks for %s: %s task(s)", self, len(tis))
    for ti in tis:
        ti.task = dag.get_task(ti.task_id)

    start_dttm = timezone.utcnow()
    unfinished_states = State.unfinished()
    finished_states = State.finished() + [State.UPSTREAM_FAILED]
    unfinished_tasks = [t for t in tis if t.state in unfinished_states]
    finished_tasks = [t for t in tis if t.state in finished_states]
    none_depends_on_past = not any(t.task.depends_on_past for t in unfinished_tasks)
    none_task_concurrency = all(t.task.task_concurrency is None for t in unfinished_tasks)

    if unfinished_tasks:
        scheduleable_tasks = [ut for ut in unfinished_tasks if ut.state in SCHEDULEABLE_STATES]
        self.log.debug(
            "number of scheduleable tasks for %s: %s task(s)",
            self, len(scheduleable_tasks))
        ready_tis, changed_tis = self._get_ready_tis(scheduleable_tasks, finished_tasks, session)
        self.log.debug("ready tis length for %s: %s task(s)", self, len(ready_tis))
        if none_depends_on_past and none_task_concurrency:
            # small speed up
            are_runnable_tasks = (
                ready_tis
                or self._are_premature_tis(unfinished_tasks, finished_tasks, session)
                or changed_tis
            )

    duration = timezone.utcnow() - start_dttm
    Stats.timing(f"dagrun.dependency-check.{self.dag_id}", duration)

    leaf_task_ids = {t.task_id for t in dag.leaves}
    leaf_states = [ti.state for ti in tis if ti.task_id in leaf_task_ids]

    def finish(final_state, success: bool, reason: str):
        """Set the final state and invoke the DAG callback."""
        self.set_state(final_state)
        dag.handle_callback(self, success=success, reason=reason, session=session)

    if not unfinished_tasks and any(s in {State.FAILED, State.UPSTREAM_FAILED} for s in leaf_states):
        # nothing left to run and at least one leaf failed
        self.log.error('Marking run %s failed', self)
        finish(State.FAILED, False, 'task_failure')
    elif not unfinished_tasks and all(s in {State.SUCCESS, State.SKIPPED} for s in leaf_states):
        # nothing left to run and every leaf succeeded or was skipped
        self.log.info('Marking run %s successful', self)
        finish(State.SUCCESS, True, 'success')
    elif unfinished_tasks and none_depends_on_past and none_task_concurrency and not are_runnable_tasks:
        # every unfinished task is deadlocked
        self.log.error('Deadlock; marking run %s failed', self)
        finish(State.FAILED, False, 'all_tasks_deadlocked')
    else:
        # the leaves are not done yet, so the dag is still running
        self.set_state(State.RUNNING)

    self._emit_duration_stats_for_finished_state()

    # todo: determine we want to use with_for_update to make sure to lock the run
    session.merge(self)
    session.commit()

    return ready_tis
class InboxSession(object):
    """ Inbox custom ORM (with SQLAlchemy compatible API).

    Parameters
    ----------
    engine : <sqlalchemy.engine.Engine>
        A configured database engine to use for this session
    versioned : bool
        Do you want to enable the transaction log? When True, an
        ``after_flush`` listener records revisions via ``create_revisions``.
    ignore_soft_deletes : bool
        Whether or not to ignore soft-deleted objects in query results.
    namespace_id : int
        Namespace to limit query results with. Currently accepted but
        ignored (see the TODO in ``__init__``).
    """
    def __init__(self, engine, versioned=True, ignore_soft_deletes=True,
                 namespace_id=None):
        # TODO: support limiting on namespaces
        assert engine, "Must set the database engine"

        args = dict(bind=engine, autoflush=True, autocommit=False)
        self.ignore_soft_deletes = ignore_soft_deletes
        if ignore_soft_deletes:
            args['query_cls'] = InboxQuery
        self._session = Session(**args)

        if versioned:
            from inbox.models.transaction import create_revisions

            @event.listens_for(self._session, 'after_flush')
            def after_flush(session, flush_context):
                """
                Hook to log revision snapshots. Must be post-flush in order to
                grab object IDs on new objects.
                """
                create_revisions(session)

    def query(self, *args, **kwargs):
        """Return a query, excluding soft-deleted rows when configured to."""
        q = self._session.query(*args, **kwargs)
        if self.ignore_soft_deletes:
            return q.options(IgnoreSoftDeletesOption())
        return q

    def add(self, instance):
        """Add ``instance``; refuses soft-deleted objects when soft deletes are ignored."""
        if not self.ignore_soft_deletes or not instance.is_deleted:
            self._session.add(instance)
        else:
            raise Exception("Why are you adding a deleted object?")

    def add_all(self, instances):
        """Add every object in ``instances`` (any iterable, including generators).

        Raises if soft deletes are ignored and any instance is deleted; in that
        case nothing is added.
        """
        # Materialise first: checking ``is_deleted`` must not exhaust a
        # generator before it reaches the underlying session.
        instances = list(instances)
        # Only inspect ``is_deleted`` when it matters (mirrors ``add``).
        if self.ignore_soft_deletes and any(i.is_deleted for i in instances):
            raise Exception("Why are you adding a deleted object?")
        self._session.add_all(instances)

    def delete(self, instance):
        """Soft-delete (mark and keep) or hard-delete, depending on configuration."""
        if self.ignore_soft_deletes:
            instance.mark_deleted()
            # just to make sure
            self._session.add(instance)
        else:
            self._session.delete(instance)

    def begin(self):
        self._session.begin()

    def commit(self):
        self._session.commit()

    def rollback(self):
        self._session.rollback()

    def flush(self):
        self._session.flush()

    def close(self):
        self._session.close()

    def expunge(self, obj):
        self._session.expunge(obj)

    def merge(self, obj, load=True):
        """Merge ``obj`` into the session; ``load`` is forwarded to ``Session.merge``."""
        return self._session.merge(obj, load=load)

    @property
    def no_autoflush(self):
        return self._session.no_autoflush