Example #1
    def verify_integrity(self, session: Session = None):
        """
        Verifies the DagRun by checking for removed tasks or tasks that are not in the
        database yet. It will set state to removed or add the task if required.

        :param session: Sqlalchemy ORM Session
        :type session: Session
        """
        dag = self.get_dag()
        tis = self.get_task_instances(session=session)

        # check for removed or restored tasks
        task_ids = set()
        for ti in tis:
            task_instance_mutation_hook(ti)
            task_ids.add(ti.task_id)
            task = None
            try:
                task = dag.get_task(ti.task_id)
            except AirflowException:
                if ti.state == State.REMOVED:
                    pass  # ti has already been removed, just ignore it
                elif self.state != State.RUNNING and not dag.partial:
                    self.log.warning(
                        "Failed to get task '%s' for dag '%s'. "
                        "Marking it as removed.", ti, dag)
                    Stats.incr("task_removed_from_dag.{}".format(dag.dag_id),
                               1, 1)
                    ti.state = State.REMOVED

            should_restore_task = (task
                                   is not None) and ti.state == State.REMOVED
            if should_restore_task:
                self.log.info(
                    "Restoring task '%s' which was previously "
                    "removed from DAG '%s'", ti, dag)
                Stats.incr("task_restored_to_dag.{}".format(dag.dag_id), 1, 1)
                ti.state = State.NONE
            session.merge(ti)

        # check for missing tasks
        for task in dag.task_dict.values():
            if task.start_date > self.execution_date and not self.is_backfill:
                continue

            if task.task_id not in task_ids:
                Stats.incr("task_instance_created-{}".format(task.task_type),
                           1, 1)
                ti = TI(task, self.execution_date)
                task_instance_mutation_hook(ti)
                session.add(ti)

        try:
            session.commit()
        except IntegrityError as err:
            self.log.info(str(err))
            self.log.info('Hit IntegrityError while creating the TIs for '
                          f'{dag.dag_id} - {self.execution_date}.')
            self.log.info('Doing session rollback.')
            session.rollback()
Example #2
    def test_set_object_local_id(self):
        """
            Test the method _set_object_local_id(self, obj, local_id)
            Test scenario:
            Set the local_id of the specified object when the pvc_id is none
        """

        obj_id = self.powerVCMapping.id
        self.powerVCMapping.pvc_id = None
        self.powerVCMapping.local_id = None
        self.powerVCMapping.status = None

        self.aMox.StubOutWithMock(session, 'query')
        session.query(model.PowerVCMapping).AndReturn(query)

        self.aMox.StubOutWithMock(query, 'filter_by')
        query.filter_by(id=obj_id).AndReturn(query)

        self.aMox.StubOutWithMock(query, 'one')
        query.one().AndReturn(self.powerVCMapping)

        self.aMox.StubOutWithMock(session, 'merge')
        session.merge(self.powerVCMapping).AndReturn("")

        self.aMox.ReplayAll()
        self.powervcagentdb._set_object_local_id(self.powerVCMapping, 'test')
        self.aMox.VerifyAll()
        self.assertEqual(self.powerVCMapping.status, 'Creating')
        self.assertEqual(self.powerVCMapping.local_id, 'test')
        self.aMox.UnsetStubs()
Example #3
    def verify_integrity(self, session: Session = NEW_SESSION):
        """
        Verifies the DagRun by checking for removed tasks or tasks that are not in the
        database yet. It will set state to removed or add the task if required.

        :param session: Sqlalchemy ORM Session
        :type session: Session
        """
        from airflow.settings import task_instance_mutation_hook

        dag = self.get_dag()
        tis = self.get_task_instances(session=session)

        # check for removed or restored tasks
        task_ids = set()
        for ti in tis:
            task_instance_mutation_hook(ti)
            task_ids.add(ti.task_id)
            task = None
            try:
                task = dag.get_task(ti.task_id)
            except AirflowException:
                if ti.state == State.REMOVED:
                    pass  # ti has already been removed, just ignore it
                elif self.state != State.RUNNING and not dag.partial:
                    self.log.warning("Failed to get task '%s' for dag '%s'. Marking it as removed.", ti, dag)
                    Stats.incr(f"task_removed_from_dag.{dag.dag_id}", 1, 1)
                    ti.state = State.REMOVED

            should_restore_task = (task is not None) and ti.state == State.REMOVED
            if should_restore_task:
                self.log.info("Restoring task '%s' which was previously removed from DAG '%s'", ti, dag)
                Stats.incr(f"task_restored_to_dag.{dag.dag_id}", 1, 1)
                ti.state = State.NONE
            session.merge(ti)

        # check for missing tasks
        for task in dag.task_dict.values():
            if task.start_date > self.execution_date and not self.is_backfill:
                continue

            if task.task_id not in task_ids:
                Stats.incr(f"task_instance_created-{task.task_type}", 1, 1)
                ti = TI(task, execution_date=None, run_id=self.run_id)
                task_instance_mutation_hook(ti)
                session.add(ti)

        try:
            session.flush()
        except IntegrityError as err:
            self.log.info(str(err))
            self.log.info('Hit IntegrityError while creating the TIs for %s - %s', dag.dag_id, self.run_id)
            self.log.info('Doing session rollback.')
            # TODO[HA]: We probably need to savepoint this so we can keep the transaction alive.
            session.rollback()
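
A minimal sketch of the savepoint idea mentioned in the TODO above (not Airflow's actual implementation; the helper name and signature are hypothetical, and SQLAlchemy 1.4+ is assumed):

from sqlalchemy.exc import IntegrityError

def flush_with_savepoint(session, new_objects):
    """Flush new_objects inside a SAVEPOINT; on conflict, discard only them."""
    try:
        with session.begin_nested():  # emits SAVEPOINT, releases it on success
            session.add_all(new_objects)
            session.flush()
        return True
    except IntegrityError:
        # begin_nested() has already rolled back to the SAVEPOINT,
        # so the enclosing transaction stays usable.
        return False
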
Example #4
def synchronize_log_template(*, session: Session = NEW_SESSION) -> None:
    """Synchronize log template configs with table.

    This checks if the last row fully matches the current config values, and
    inserts a new row if not.
    """
    stored = session.query(LogTemplate).order_by(LogTemplate.id.desc()).first()
    filename = conf.get("logging", "log_filename_template")
    prefix = conf.get("logging", "task_log_prefix_template")
    if stored and stored.filename == filename and stored.task_prefix == prefix:
        return
    session.merge(LogTemplate(filename=filename, task_prefix=prefix))
Example #5
    def add(cls, session: Session, mpid: Union[int, str]) -> None:
        """
        Add a record of a patient who wishes to opt out.

        Args:
            session: SQLAlchemy database session for the secret admin database
            mpid: MPID of the patient who is opting out
        """
        log.debug(f"Adding opt-out for MPID {mpid}")
        # noinspection PyArgumentList
        newthing = cls(mpid=mpid)
        session.merge(newthing)
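
The merge() call above behaves as an upsert keyed on the primary key: if a row with that mpid already exists it is updated in place rather than duplicated. A small self-contained sketch of that behaviour, assuming SQLAlchemy 1.4+ and a hypothetical OptOut model with an in-memory SQLite database:

from sqlalchemy import Column, Integer, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class OptOut(Base):
    __tablename__ = "opt_out"
    mpid = Column(Integer, primary_key=True)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.merge(OptOut(mpid=42))  # no existing row -> INSERT
    session.merge(OptOut(mpid=42))  # row already present -> harmless UPDATE, no duplicate
    session.commit()
    print(session.query(OptOut).count())  # 1
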
Example #6
def downgrade():

    session = Session(op.get_bind())

    op.alter_column('data_source',
                    'datasource_configuration_id',
                    nullable=True)
    for ds in session.query(DataSource).all():
        ds.datasource_configuration_id = None
        session.merge(ds)

    session.query(DataSourceConfiguration).delete()
Example #7
def synchronize_log_filename_template(*,
                                      session: Session = NEW_SESSION) -> None:
    """Synchronize log filename template config with table.

    This checks if the last row (based on timestamp) matches the current
    config value, and inserts a new row if not.
    """
    stored = session.query(LogFilename.template).order_by(
        LogFilename.id.desc()).limit(1).scalar()
    config = conf.get("logging", "LOG_FILENAME_TEMPLATE")
    if stored == config:
        return
    session.merge(LogFilename(template=config))
Example #8
def insert_instruments(session: Session, dest_ric: Path, logger: Logger) -> None:
    with dest_ric.open(mode='r') as f:
        reader = csv.reader(f, delimiter=',')
        next(reader)
        for fields in reader:
            ric = fields[0]
            desc = fields[1]
            currency = fields[2]
            type_ = fields[3]
            exchange = fields[4] if type_ in [EQUITY, FUTURES] else None
            instrument = Instrument(ric, desc, currency, type_, exchange)
            session.merge(instrument)
    session.commit()
Example #9
def draw_card(session: Session, order: Order) -> str:
    """根据给定的订单随机抽取一张卡片.
    ### Args:
    ``session``: 用于连接数据库的SQLAlchemy线程.\n
    ``order``: 待抽卡的订单.\n
    ### Result:
    ``message``: 抽卡后的反馈信息.\n
    """
    if not hasattr(order, 'nickname'):
        order.nickname = ''
    # Draw a card
    threshold = float(setting.read_config('card', 'threshold'))
    divend = order.amount / threshold
    rand = abs(random.normalvariate(0, math.sqrt(divend)))
    if divend > 25:
        rand += math.log2(divend / 25)
    if 1 < rand <= 2.5:
        rand = 1
    elif 2.5 < rand <= 5:
        rand = 2
    elif rand > 5:
        rand = 3
    rarity = int(rand)
    card_query = session.query(Card).filter(Card.rarity == rarity)
    card = card_query[random.randint(0, card_query.count() - 1)]
    # Insert a draw record for the order
    session.add(
        Card_Order(order_id=order.id, rarity=rarity, type_id=card.type_id))
    # Insert a draw record for the user
    user = find_user(session, order.platform, order.user_id, order.nickname)
    order.nickname = user.nickname
    session.merge(
        Card_User(user_id=user.id, rarity=rarity, type_id=card.type_id))
    session.flush()
    # Build the feedback message
    collected_cards = session.query(Card_User).\
        filter(Card_User.user_id == user.id).\
        filter(Card_User.rarity == rarity).count()
    info_dict = {
        'nickname': order.nickname,
        'rarity': setting.rarity()[card.rarity],
        'name': card.name,
        'context': card.context,
        'user_amount': collected_cards,
        'total_amount': card_query.count(),
        'image': f'[CQ:image,file={card.file_name}]',
    }
    logger.debug('%s drew a %s card: %s', user.nickname, info_dict['rarity'],
                 card.name)
    pattern = setting.read_config('card', 'pattern')
    return pattern.format(**info_dict)
Example #10
def _update_existing_or_create(ingested_entity: schema.JusticeCountsDatabaseEntity, session: Session) \
        -> schema.JusticeCountsDatabaseEntity:
    # Note: Using on_conflict_do_update to resolve whether there is an existing entity could be more efficient as it
    # wouldn't incur multiple roundtrips. However for some entities we need to know whether there is an existing entity
    # (e.g. table instance) so we can clear child entities, so we probably wouldn't win much if anything.
    table = ingested_entity.__table__
    [unique_constraint] = [
        constraint for constraint in table.constraints
        if isinstance(constraint, UniqueConstraint)
    ]
    query = session.query(table)
    for column in unique_constraint:
        # TODO(#4477): Instead of making an assumption about how the property name is formed from the column name, use
        # an Entity method here to follow the foreign key relationship.
        if column.name.endswith('_id'):
            value = getattr(ingested_entity, column.name[:-len('_id')]).id
        else:
            value = getattr(ingested_entity, column.name)
        # Cast to the type because array types aren't deduced properly.
        query = query.filter(column == cast(value, column.type))
    table_entity: Optional[JusticeCountsBase] = query.first()
    if table_entity is not None:
        # TODO(#4477): Instead of assuming the primary key field is named `id`, use an Entity method.
        ingested_entity.id = table_entity.id
        # TODO(#4477): Merging here doesn't seem perfect, although it should work so long as the given entity always has
        # all the properties set explicitly. To avoid the merge, the method could instead take in the entity class as
        # one parameter and the parameters to construct it separately and then query based on those parameters. However
        # this would likely make mypy less useful.
        merged_entity = session.merge(ingested_entity)
        return merged_entity
    session.add(ingested_entity)
    return ingested_entity
Example #11
    def update(self, entity_class, data, target=None):
        if not IEntity.providedBy(data):  # pylint: disable=E1101
            upd_ent = self.__run_traversal(entity_class, data, target,
                                           RELATION_OPERATIONS.UPDATE)
        else:
            upd_ent = SaSession.merge(self, data)
        return upd_ent
Example #12
    def _clear_stuck_queued_tasks(self,
                                  session: Session = NEW_SESSION) -> None:
        """
        Tasks can get stuck in queued state in DB while still not in
        worker. This happens when the worker is autoscaled down and
        the task is queued but has not been picked up by any worker prior to the scaling.

        In such situation, we update the task instance state to scheduled so that
        it can be queued again. We chose to use task_adoption_timeout to decide when
        a queued task is considered stuck and should be reschelduled.
        """
        if not isinstance(app.backend, DatabaseBackend):
            # We only want to do this for database backends where
            # this case has been spotted
            return
        # We use this instead of using bulk_state_fetcher because we
        # may not have the stuck task in self.tasks and we don't want
        # to clear task in self.tasks too
        session_ = app.backend.ResultSession()
        task_cls = getattr(app.backend, "task_cls", TaskDb)
        with session_cleanup(session_):
            celery_task_ids = [
                t.task_id for t in session_.query(task_cls.task_id).filter(
                    ~task_cls.status.in_(
                        [celery_states.SUCCESS, celery_states.FAILURE])).all()
            ]
        self.log.debug("Checking for stuck queued tasks")

        max_allowed_time = utcnow() - self.task_adoption_timeout

        for task in session.query(TaskInstance).filter(
                TaskInstance.state == State.QUEUED,
                TaskInstance.queued_dttm < max_allowed_time):
            if task.key in self.queued_tasks or task.key in self.running:
                continue

            if task.external_executor_id in celery_task_ids:
                # The task is still running in the worker
                continue

            self.log.info(
                'TaskInstance: %s found in queued state for more than %s seconds, rescheduling',
                task,
                self.task_adoption_timeout.total_seconds(),
            )
            task.state = State.SCHEDULED
            session.merge(task)
Example #13
def _set_dag_run_state(dag_id: str, run_id: str, state: DagRunState, session: SASession = NEW_SESSION):
    """
    Helper method that set dag run state in the DB.

    :param dag_id: dag_id of target dag run
    :param run_id: run id of target dag run
    :param state: target state
    :param session: database session
    """
    dag_run = session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == run_id).one()
    dag_run.state = state
    if state == State.RUNNING:
        dag_run.start_date = timezone.utcnow()
        dag_run.end_date = None
    else:
        dag_run.end_date = timezone.utcnow()
    session.merge(dag_run)
Example #14
def load_chunk(session: Session, fetch_from: date, fetch_to: date, now: date, tan_callback):
    global error
    logging.info("Fetching TXs from %s to %s", fetch_from, fetch_to)
    error = False
    
    with hbci_client.get_account(tan_callback) as (conn, acc):
        def log_callback(_, response):
            if response.code[0] not in ('0', '1', '3'): # 0&1 info, 3 warning, rest err
                global error
                error = True

        conn.add_response_callback(log_callback)
        txs = conn.get_transactions(acc, fetch_from, fetch_to)

        while isinstance(txs, NeedTANResponse):
            logging.info("Calling tan callback for TXs %s", txs)
            tan = tan_callback(txs)
            if not tan:
                logging.error("No TAN got, aborting")
                return
            txs = conn.send_tan(txs, tan)

    logging.info("got txs: len: %s, type: %s, %s", len(txs), type(txs), txs)
    new_txs = []
    for tx in txs:
        tx = Transaction(tx)
        if tx.entry_date > now:
            logging.warning("Ignoring future tx: %s", tx)
            continue
        fuzz = timedelta(days=config.import_fuzz_grace_days)
        if tx.date + fuzz < fetch_from or tx.date - fuzz > fetch_to:
            logging.warning("Ignoring tx which is not in requested range: %s", tx)
            continue
        new_txs.append(tx)

    logging.info("Fetched %d txs", len(new_txs))
    if error:
        logging.error("Errors occurred")
        session.rollback()
        return False

    for tx in new_txs:
        session.merge(tx)
    return True
Example #15
def upgrade():
    session = Session(op.get_bind())

    for company in session.query(Company).all():

        datasource_config = DataSourceConfiguration(
            name=DATASOURCE_CONFIG_NAME,
            meta=DATASOURCE_META,
            company_id=company.id)

        session.add(datasource_config)
        for ds_entity in session.query(DataSource).filter(
                DataSource.company_id == company.id).all():
            ds_entity.datasource_configuration_id = datasource_config.id
            ds_entity.meta = datasource_config.meta
            session.merge(ds_entity)

    op.alter_column('data_source',
                    'datasource_configuration_id',
                    nullable=False)
Example #16
def verify_dagruns(
    dag_runs: Iterable[DagRun],
    commit: bool,
    state: DagRunState,
    session: SASession,
    current_task: BaseOperator,
):
    """Verifies integrity of dag_runs.

    :param dag_runs: dag runs to verify
    :param commit: whether dag runs state should be updated
    :param state: state of the dag_run to set if commit is True
    :param session: session to use
    :param current_task: current task
    :return:
    """
    for dag_run in dag_runs:
        dag_run.dag = current_task.subdag
        dag_run.verify_integrity()
        if commit:
            dag_run.state = state
            session.merge(dag_run)
Example #17
    def update_dag_warnings(self, *, session: Session, dagbag: DagBag) -> None:
        """
        For the DAGs in the given DagBag, record any associated configuration warnings and clear
        warnings for files that no longer have them. These are usually displayed through the
        Airflow UI so that users know that there are issues parsing DAGs.

        :param session: session for ORM operations
        :param dagbag: DagBag containing DAGs with configuration warnings
        """
        self._validate_task_pools(dagbag=dagbag)

        stored_warnings = set(
            session.query(DagWarning).filter(
                DagWarning.dag_id.in_(dagbag.dags.keys())).all())

        for warning_to_delete in stored_warnings - self.dag_warnings:
            session.delete(warning_to_delete)

        for warning_to_add in self.dag_warnings:
            session.merge(warning_to_add)

        session.commit()
Example #18
def add_and_scan_channel(
    session: Session,
    channel: schema.Channel,
    autodownload: bool,
    dryrun: bool,
) -> None:
    channel.autodownload = autodownload
    session.merge(channel)

    try:
        videos = retrieve_videos(channel)
    except Exception as e:
        print(f"Couldn't retrieve videos: {e}", file=sys.stderr)
        return

    for v in videos:
        v.downloaded = 1
        session.merge(v)

    if not dryrun:
        session.commit()
    print(f'Subscribed to "{channel.name}"')
Example #19
def _set_dag_run_state(dag_id: str,
                       execution_date: datetime,
                       state: TaskInstanceState,
                       session: SASession = NEW_SESSION):
    """
    Helper method that set dag run state in the DB.

    :param dag_id: dag_id of target dag run
    :param execution_date: the execution date from which to start looking
    :param state: target state
    :param session: database session
    """
    dag_run = (session.query(DagRun).filter(
        DagRun.dag_id == dag_id,
        DagRun.execution_date == execution_date).one())
    dag_run.state = state
    if state == TaskInstanceState.RUNNING:
        dag_run.start_date = timezone.utcnow()
        dag_run.end_date = None
    else:
        dag_run.end_date = timezone.utcnow()
    session.merge(dag_run)
Example #20
File: add.py Project: jtpavlock/Moe
def _add_item(config: Config, session: Session, item_path: Path):
    """Adds a LibItem to the library from a given path.

    Args:
        config: Moe config.
        session: Current db session.
        item_path: Filesystem path of the item.

    Raises:
        AddError: Unable to add the item to the library.
    """
    item: LibItem
    if item_path.is_file():
        item = _add_track(item_path)
        old_album = item.album_obj
    elif item_path.is_dir():
        item = _add_album(item_path)
        old_album = item
    else:
        raise AddError(f"Path not found: {item_path}")

    old_album.merge(old_album.get_existing(session),
                    overwrite_album_info=False)
    new_albums = config.plugin_manager.hook.import_album(config=config,
                                                         session=session,
                                                         album=old_album)
    if new_albums:
        add_album = prompt.run_prompt(config, session, old_album,
                                      new_albums[0])
    else:
        add_album = old_album

    if add_album:
        config.plugin_manager.hook.pre_add(config=config,
                                           session=session,
                                           album=add_album)
        add_album = session.merge(add_album)
Example #21
    def expand_mapped_task(self, run_id: str, *,
                           session: Session) -> Sequence["TaskInstance"]:
        """Create the mapped task instances for mapped task.

        :return: The mapped task instances, in ascending order by map index.
        """
        from airflow.models.taskinstance import TaskInstance
        from airflow.settings import task_instance_mutation_hook

        total_length = functools.reduce(
            operator.mul,
            self._get_map_lengths(run_id, session=session).values())

        state: Optional[TaskInstanceState] = None
        unmapped_ti: Optional[TaskInstance] = (
            session.query(TaskInstance).filter(
                TaskInstance.dag_id == self.dag_id,
                TaskInstance.task_id == self.task_id,
                TaskInstance.run_id == run_id,
                TaskInstance.map_index == -1,
                or_(TaskInstance.state.in_(State.unfinished),
                    TaskInstance.state.is_(None)),
            ).one_or_none())

        ret: List[TaskInstance] = []

        if unmapped_ti:
            # The unmapped task instance still exists and is unfinished, i.e. we
            # haven't tried to run it before.
            if total_length < 1:
                # If the upstream maps this to a zero-length value, simply mark the
                # unmapped task instance as SKIPPED (if needed).
                self.log.info(
                    "Marking %s as SKIPPED since the map has %d values to expand",
                    unmapped_ti,
                    total_length,
                )
                unmapped_ti.state = TaskInstanceState.SKIPPED
                session.flush()
                return ret
            # Otherwise convert this into the first mapped index, and create
            # TaskInstance for other indexes.
            unmapped_ti.map_index = 0
            state = unmapped_ti.state
            self.log.debug("Updated in place to become %s", unmapped_ti)
            ret.append(unmapped_ti)
            indexes_to_map = range(1, total_length)
        else:
            # Only create "missing" ones.
            current_max_mapping = (session.query(
                func.max(TaskInstance.map_index)).filter(
                    TaskInstance.dag_id == self.dag_id,
                    TaskInstance.task_id == self.task_id,
                    TaskInstance.run_id == run_id,
                ).scalar())
            indexes_to_map = range(current_max_mapping + 1, total_length)

        for index in indexes_to_map:
            # TODO: Make more efficient with bulk_insert_mappings/bulk_save_mappings.
            # TODO: Change `TaskInstance` ctor to take Operator, not BaseOperator
            ti = TaskInstance(self,
                              run_id=run_id,
                              map_index=index,
                              state=state)  # type: ignore
            self.log.debug("Expanding TIs upserted %s", ti)
            task_instance_mutation_hook(ti)
            ti = session.merge(ti)
            ti.task = self
            ret.append(ti)

        # Set to "REMOVED" any (old) TaskInstances with map indices greater
        # than the current map value
        session.query(TaskInstance).filter(
            TaskInstance.dag_id == self.dag_id,
            TaskInstance.task_id == self.task_id,
            TaskInstance.run_id == run_id,
            TaskInstance.map_index >= total_length,
        ).update({TaskInstance.state: TaskInstanceState.REMOVED})

        session.flush()

        return ret
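
The ti.task = self assignment above is needed because Session.merge() returns its own persistent instance and copies only mapped column attributes, so ad-hoc attributes such as task are lost. A small sketch of that behaviour, assuming SQLAlchemy 1.4+ and a hypothetical Item model with an in-memory SQLite database:

from sqlalchemy import Column, Integer, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class Item(Base):
    __tablename__ = "item"
    id = Column(Integer, primary_key=True)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    pending = Item(id=1)
    pending.task = "not a mapped column"  # ad-hoc attribute, never persisted
    merged = session.merge(pending)
    print(merged is pending)        # False: merge() hands back its own instance
    print(hasattr(merged, "task"))  # False: only column state is copied over
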
Example #22
    def manage_slas(self, dag: DAG, session: Session = None) -> None:
        """
        Find all tasks that have SLAs defined and send alert emails where
        needed. New SLA misses are also recorded in the database.

        We are assuming that the scheduler runs often, so we only check for
        tasks that should have succeeded in the past hour.
        """
        self.log.info("Running SLA Checks for %s", dag.dag_id)
        if not any(isinstance(ti.sla, timedelta) for ti in dag.tasks):
            self.log.info(
                "Skipping SLA check for %s because no tasks in DAG have SLAs",
                dag)
            return

        qry = (session.query(
            TI.task_id,
            func.max(TI.execution_date).label('max_ti')).with_hint(
                TI, 'USE INDEX (PRIMARY)',
                dialect_name='mysql').filter(TI.dag_id == dag.dag_id).filter(
                    or_(TI.state == State.SUCCESS,
                        TI.state == State.SKIPPED)).filter(
                            TI.task_id.in_(dag.task_ids)).group_by(
                                TI.task_id).subquery('sq'))

        max_tis: List[TI] = (session.query(TI).filter(
            TI.dag_id == dag.dag_id,
            TI.task_id == qry.c.task_id,
            TI.execution_date == qry.c.max_ti,
        ).all())

        ts = timezone.utcnow()
        for ti in max_tis:
            task = dag.get_task(ti.task_id)
            if task.sla and not isinstance(task.sla, timedelta):
                raise TypeError(
                    f"SLA is expected to be timedelta object, got "
                    f"{type(task.sla)} in {task.dag_id}:{task.task_id}")

            dttm = dag.following_schedule(ti.execution_date)
            while dttm < timezone.utcnow():
                following_schedule = dag.following_schedule(dttm)
                if following_schedule + task.sla < timezone.utcnow():
                    session.merge(
                        SlaMiss(task_id=ti.task_id,
                                dag_id=ti.dag_id,
                                execution_date=dttm,
                                timestamp=ts))
                dttm = dag.following_schedule(dttm)
        session.commit()

        # pylint: disable=singleton-comparison
        slas: List[SlaMiss] = (
            session.query(SlaMiss).filter(SlaMiss.notification_sent == False,
                                          SlaMiss.dag_id == dag.dag_id)  # noqa
            .all())
        # pylint: enable=singleton-comparison

        if slas:  # pylint: disable=too-many-nested-blocks
            sla_dates: List[datetime.datetime] = [
                sla.execution_date for sla in slas
            ]
            fetched_tis: List[TI] = (session.query(TI).filter(
                TI.state != State.SUCCESS, TI.execution_date.in_(sla_dates),
                TI.dag_id == dag.dag_id).all())
            blocking_tis: List[TI] = []
            for ti in fetched_tis:
                if ti.task_id in dag.task_ids:
                    ti.task = dag.get_task(ti.task_id)
                    blocking_tis.append(ti)
                else:
                    session.delete(ti)
                    session.commit()

            task_list = "\n".join(sla.task_id + ' on ' +
                                  sla.execution_date.isoformat()
                                  for sla in slas)
            blocking_task_list = "\n".join(ti.task_id + ' on ' +
                                           ti.execution_date.isoformat()
                                           for ti in blocking_tis)
            # Track whether email or any alert notification sent
            # We consider email or the alert callback as notifications
            email_sent = False
            notification_sent = False
            if dag.sla_miss_callback:
                # Execute the alert callback
                self.log.info('Calling SLA miss callback')
                try:
                    dag.sla_miss_callback(dag, task_list, blocking_task_list,
                                          slas, blocking_tis)
                    notification_sent = True
                except Exception:  # pylint: disable=broad-except
                    self.log.exception(
                        "Could not call sla_miss_callback for DAG %s",
                        dag.dag_id)
            email_content = f"""\
            Here's a list of tasks that missed their SLAs:
            <pre><code>{task_list}\n</code></pre>
            Blocking tasks:
            <pre><code>{blocking_task_list}</code></pre>
            Airflow Webserver URL: {conf.get(section='webserver', key='base_url')}
            """

            tasks_missed_sla = []
            for sla in slas:
                try:
                    task = dag.get_task(sla.task_id)
                except TaskNotFound:
                    # task already deleted from DAG, skip it
                    self.log.warning(
                        "Task %s doesn't exist in DAG anymore, skipping SLA miss notification.",
                        sla.task_id)
                    continue
                tasks_missed_sla.append(task)

            emails: Set[str] = set()
            for task in tasks_missed_sla:
                if task.email:
                    if isinstance(task.email, str):
                        emails |= set(get_email_address_list(task.email))
                    elif isinstance(task.email, (list, tuple)):
                        emails |= set(task.email)
            if emails:
                try:
                    send_email(emails,
                               f"[airflow] SLA miss on DAG={dag.dag_id}",
                               email_content)
                    email_sent = True
                    notification_sent = True
                except Exception:  # pylint: disable=broad-except
                    Stats.incr('sla_email_notification_failure')
                    self.log.exception(
                        "Could not send SLA Miss email notification for DAG %s",
                        dag.dag_id)
            # If we sent any notification, update the sla_miss table
            if notification_sent:
                for sla in slas:
                    sla.email_sent = email_sent
                    sla.notification_sent = True
                    session.merge(sla)
            session.commit()
Example #23
    def add(cls, session: Session, mpid: int) -> None:
        log.debug("Adding opt-out for MPID {}".format(mpid))
        newthing = cls(mpid=mpid)
        session.merge(newthing)
Example #24
    def update_state(
        self,
        session: Session = None,
        execute_callbacks: bool = True
    ) -> Tuple[List[TI], Optional[callback_requests.DagCallbackRequest]]:
        """
        Determines the overall state of the DagRun based on the state
        of its TaskInstances.

        :param session: Sqlalchemy ORM Session
        :type session: Session
        :param execute_callbacks: Should dag callbacks (success/failure, SLA etc) be invoked
            directly (default: true) or recorded as a pending request in the ``callback`` property
        :type execute_callbacks: bool
        :return: Tuple containing tis that can be scheduled in the current loop & `callback` that
            needs to be executed
        """
        # Callback to execute in case of Task Failures
        callback: Optional[callback_requests.DagCallbackRequest] = None

        start_dttm = timezone.utcnow()
        self.last_scheduling_decision = start_dttm
        with Stats.timer(f"dagrun.dependency-check.{self.dag_id}"):
            dag = self.get_dag()
            info = self.task_instance_scheduling_decisions(session)

            tis = info.tis
            schedulable_tis = info.schedulable_tis
            changed_tis = info.changed_tis
            finished_tasks = info.finished_tasks
            unfinished_tasks = info.unfinished_tasks

            none_depends_on_past = all(not t.task.depends_on_past
                                       for t in unfinished_tasks)
            none_task_concurrency = all(t.task.max_active_tis_per_dag is None
                                        for t in unfinished_tasks)
            none_deferred = all(t.state != State.DEFERRED
                                for t in unfinished_tasks)

            if unfinished_tasks and none_depends_on_past and none_task_concurrency and none_deferred:
                # small speed up
                are_runnable_tasks = (schedulable_tis
                                      or self._are_premature_tis(
                                          unfinished_tasks, finished_tasks,
                                          session) or changed_tis)

        leaf_task_ids = {t.task_id for t in dag.leaves}
        leaf_tis = [ti for ti in tis if ti.task_id in leaf_task_ids]

        # if all roots finished and at least one failed, the run failed
        if not unfinished_tasks and any(leaf_ti.state in State.failed_states
                                        for leaf_ti in leaf_tis):
            self.log.error('Marking run %s failed', self)
            self.set_state(State.FAILED)
            if execute_callbacks:
                dag.handle_callback(self,
                                    success=False,
                                    reason='task_failure',
                                    session=session)
            elif dag.has_on_failure_callback:
                callback = callback_requests.DagCallbackRequest(
                    full_filepath=dag.fileloc,
                    dag_id=self.dag_id,
                    execution_date=self.execution_date,
                    is_failure_callback=True,
                    msg='task_failure',
                )

        # if all leaves succeeded and no unfinished tasks, the run succeeded
        elif not unfinished_tasks and all(leaf_ti.state in State.success_states
                                          for leaf_ti in leaf_tis):
            self.log.info('Marking run %s successful', self)
            self.set_state(State.SUCCESS)
            if execute_callbacks:
                dag.handle_callback(self,
                                    success=True,
                                    reason='success',
                                    session=session)
            elif dag.has_on_success_callback:
                callback = callback_requests.DagCallbackRequest(
                    full_filepath=dag.fileloc,
                    dag_id=self.dag_id,
                    execution_date=self.execution_date,
                    is_failure_callback=False,
                    msg='success',
                )

        # if *all tasks* are deadlocked, the run failed
        elif (unfinished_tasks and none_depends_on_past
              and none_task_concurrency and none_deferred
              and not are_runnable_tasks):
            self.log.error('Deadlock; marking run %s failed', self)
            self.set_state(State.FAILED)
            if execute_callbacks:
                dag.handle_callback(self,
                                    success=False,
                                    reason='all_tasks_deadlocked',
                                    session=session)
            elif dag.has_on_failure_callback:
                callback = callback_requests.DagCallbackRequest(
                    full_filepath=dag.fileloc,
                    dag_id=self.dag_id,
                    execution_date=self.execution_date,
                    is_failure_callback=True,
                    msg='all_tasks_deadlocked',
                )

        # finally, if the roots aren't done, the dag is still running
        else:
            self.set_state(State.RUNNING)

        self._emit_true_scheduling_delay_stats_for_finished_state(
            finished_tasks)
        self._emit_duration_stats_for_finished_state()

        session.merge(self)

        return schedulable_tis, callback
Example #25
    def expand_mapped_task(
            self, run_id: str, *,
            session: Session) -> Tuple[Sequence["TaskInstance"], int]:
        """Create the mapped task instances for mapped task.

        :return: The newly created mapped TaskInstances (if any) in ascending order by map index, and the
            maximum map_index.
        """
        from airflow.models.taskinstance import TaskInstance
        from airflow.settings import task_instance_mutation_hook

        total_length: Optional[int]
        try:
            total_length = self._get_specified_expand_input(
            ).get_total_map_length(run_id, session=session)
        except NotFullyPopulated as e:
            self.log.info(
                "Cannot expand %r for run %s; missing upstream values: %s",
                self,
                run_id,
                sorted(e.missing),
            )
            total_length = None

        state: Optional[TaskInstanceState] = None
        unmapped_ti: Optional[TaskInstance] = (
            session.query(TaskInstance).filter(
                TaskInstance.dag_id == self.dag_id,
                TaskInstance.task_id == self.task_id,
                TaskInstance.run_id == run_id,
                TaskInstance.map_index == -1,
                or_(TaskInstance.state.in_(State.unfinished),
                    TaskInstance.state.is_(None)),
            ).one_or_none())

        all_expanded_tis: List[TaskInstance] = []

        if unmapped_ti:
            # The unmapped task instance still exists and is unfinished, i.e. we
            # haven't tried to run it before.
            if total_length is None:
                # If the map length cannot be calculated (due to unavailable
                # upstream sources), fail the unmapped task.
                unmapped_ti.state = TaskInstanceState.UPSTREAM_FAILED
                indexes_to_map: Iterable[int] = ()
            elif total_length < 1:
                # If the upstream maps this to a zero-length value, simply mark
                # the unmapped task instance as SKIPPED (if needed).
                self.log.info(
                    "Marking %s as SKIPPED since the map has %d values to expand",
                    unmapped_ti,
                    total_length,
                )
                unmapped_ti.state = TaskInstanceState.SKIPPED
                indexes_to_map = ()
            else:
                # Otherwise convert this into the first mapped index, and create
                # TaskInstance for other indexes.
                unmapped_ti.map_index = 0
                self.log.debug("Updated in place to become %s", unmapped_ti)
                all_expanded_tis.append(unmapped_ti)
                indexes_to_map = range(1, total_length)
            state = unmapped_ti.state
        elif not total_length:
            # Nothing to fixup.
            indexes_to_map = ()
        else:
            # Only create "missing" ones.
            current_max_mapping = (session.query(
                func.max(TaskInstance.map_index)).filter(
                    TaskInstance.dag_id == self.dag_id,
                    TaskInstance.task_id == self.task_id,
                    TaskInstance.run_id == run_id,
                ).scalar())
            indexes_to_map = range(current_max_mapping + 1, total_length)

        for index in indexes_to_map:
            # TODO: Make more efficient with bulk_insert_mappings/bulk_save_mappings.
            ti = TaskInstance(self,
                              run_id=run_id,
                              map_index=index,
                              state=state)
            self.log.debug("Expanding TIs upserted %s", ti)
            task_instance_mutation_hook(ti)
            ti = session.merge(ti)
            ti.refresh_from_task(
                self)  # session.merge() loses task information.
            all_expanded_tis.append(ti)

        # Coerce the None case to 0 -- these two are almost treated identically,
        # except the unmapped ti (if exists) is marked to different states.
        total_expanded_ti_count = total_length or 0

        # Set to "REMOVED" any (old) TaskInstances with map indices greater
        # than the current map value
        session.query(TaskInstance).filter(
            TaskInstance.dag_id == self.dag_id,
            TaskInstance.task_id == self.task_id,
            TaskInstance.run_id == run_id,
            TaskInstance.map_index >= total_expanded_ti_count,
        ).update({TaskInstance.state: TaskInstanceState.REMOVED})

        session.flush()
        return all_expanded_tis, total_expanded_ti_count - 1
Example #26
def upgrade():
    session = Session(bind=op.get_bind(), expire_on_commit=False)

    # All IRONMAN orgs need a 3 digit version
    IRONMAN_system = 'http://pcctc.org/'

    ironman_org_ids = [(id.id, id._value) for id in Identifier.query.filter(
        Identifier.system == IRONMAN_system).with_entities(
        Identifier.id, Identifier._value)]
    existing_values = [id[1] for id in ironman_org_ids]

    replacements = {}
    for io_id, io_value in ironman_org_ids:
        found = org_pattern.match(io_value)
        if found:
            # avoid probs if run again - don't add if already present
            needed = '146-0{}'.format(found.group(1))
            replacements[found.group(1)] = '0{}'.format(found.group(1))
            if needed not in existing_values:
                needed_i = Identifier(
                    use='secondary', system=IRONMAN_system, _value=needed)
            else:
                needed_i = Identifier.query.filter(
                    Identifier.system == IRONMAN_system).filter(
                    Identifier._value == needed).one()

            # add a 3 digit identifier and link with same org
            oi = OrganizationIdentifier.query.filter(
                OrganizationIdentifier.identifier_id == io_id).one()
            needed_oi = OrganizationIdentifier.query.filter(
                OrganizationIdentifier.organization_id ==
                oi.organization_id).filter(
                OrganizationIdentifier.identifier == needed_i).first()
            if not needed_oi:
                needed_i = session.merge(needed_i)
                needed_oi = OrganizationIdentifier(
                    organization_id=oi.organization_id,
                    identifier=needed_i)
                session.add(needed_oi)

    # All IRONMAN users with a 2 digit ID referencing one of the replaced
    # values needs a 3 digit version

    ironman_study_ids = Identifier.query.filter(
        Identifier.system == TRUENTH_EXTERNAL_STUDY_SYSTEM).filter(
        Identifier._value.like('170-%')).with_entities(
        Identifier.id, Identifier._value)

    for iid, ival in ironman_study_ids:
        found = study_pattern.match(ival)
        if found:
            org_segment = found.group(1)
            patient_segment = found.group(2)

            # only add if also one of the new org ids
            if org_segment not in replacements:
                continue

            needed = '170-{}-{}'.format(
                replacements[org_segment], patient_segment)

            # add a 3 digit identifier and link with same user(s),
            # if not already present
            uis = UserIdentifier.query.filter(
                UserIdentifier.identifier_id == iid)
            needed_i = Identifier.query.filter(
                Identifier.system == TRUENTH_EXTERNAL_STUDY_SYSTEM).filter(
                Identifier._value == needed).first()
            if not needed_i:
                needed_i = Identifier(
                    use='secondary', system=TRUENTH_EXTERNAL_STUDY_SYSTEM,
                    _value=needed)
            for ui in uis:
                needed_ui = UserIdentifier.query.filter(
                    UserIdentifier.user_id == ui.user_id).filter(
                    UserIdentifier.identifier == needed_i).first()
                if not needed_ui:
                    needed_ui = UserIdentifier(
                        user_id=ui.user_id, identifier=needed_i)
                    session.add(needed_ui)

    session.commit()
Example #27
    def merge(self, entity, load=True):
        self.begin()
        SaSession.merge(self, entity, load=load)
        self.commit()
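
The wrapper above opens and commits its own transaction around the merge. A minimal sketch of the same idea using the context-manager API (SQLAlchemy 1.4+ assumed; the function name is hypothetical):

from sqlalchemy.orm import Session

def merge_in_transaction(engine, entity):
    """Merge entity inside its own short-lived session and transaction."""
    with Session(engine) as session:
        with session.begin():  # COMMIT on success, ROLLBACK on error
            session.merge(entity)
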
Example #28
class CrawlProcessor(object):

    __VERSION__ = "CrawlProcessor-0.2.1"

    def __init__(self, engine, redis_server, stop_list="keyword_filter.txt"):

        if type(engine) == types.StringType:
            logging.info("Using connection string '%s'" % (engine,))
            new_engine = create_engine(engine, encoding='utf-8', isolation_level="READ COMMITTED")
            if "sqlite:" in engine:
                logging.debug("Setting text factory for unicode compat.")
                new_engine.raw_connection().connection.text_factory = str 
            self._engine = new_engine
        else:
            logging.info("Using existing engine...")
            self._engine = engine
        logging.info("Binding session...")
        self._session = Session(bind=self._engine, autocommit = False)

        if type(stop_list) == types.StringType:
            stop_list_fp = open(stop_list)
        else:
            stop_list_fp = stop_list 

        self.stop_list = set([])
        for line in stop_list_fp:
            self.stop_list.add(line.strip())

        self.cls = DocumentClassifier()
        self.dc  = DomainController(self._engine, self._session)
        self.ac  = ArticleController(self._engine, self._session)
        self.ex  = extract.TermExtractor()
        self.kwc = KeywordController(self._engine, self._session)
        self.swc = SoftwareVersionsController(self._engine, self._session)
        self.redis_kw = redis.Redis(host=redis_server, port=6379, db=1)
        self.redis_dm = redis.Redis(host=redis_server, port=6379, db=2)
        dm_session = Session(bind=self._engine, autocommit = False)
        self.drw = DomainResolutionWorker(dm_session, self.redis_dm)


    def _check_processed(self, item):
        crawl_id, record = item 
        headers, content, url, date_crawled, content_type = record

        path   = self.ac.get_path_fromurl(url)
        domain_identifier = None 
        logging.info("_check_processed: retrieving domain...")
        domain_key = self.dc.get_Domain_key(url)
        while domain_identifier == None:
            domain_identifier = self.drw.get_domain(domain_key)

        it = self._session.query(Article).filter_by(crawl_id = crawl_id).filter_by(domain_id = domain_identifier).filter_by(path = path)
        try:
            it = it.one()
            logging.error("%s: already processed", url)
            return False 
        except sqlalchemy.orm.exc.MultipleResultsFound:
            logging.error("%s: appears to have been already processed multiple times", url)
            return False 
        except sqlalchemy.orm.exc.NoResultFound:
            logging.info("%s: hasn't been processed yet", url)
            return True 

    def process_record(self, item):
        if len(item) != 2:
            raise ValueError(item)
        if not self._check_processed(item):
            return None
        ret, retries = None, 2
        while ret == None and retries > 0:
            try:
                retries -= 1
                ret = self._process_record(item)
            except Exception as ex:
                import traceback
                print >> sys.stderr, ex
                traceback.print_exc()
                raise ex 
        if ret == False:
            return None 
        return ret 


    def _process_record(self, item_arg):

        crawl_id, record = item_arg
        headers, content, url, date_crawled, content_type = record

        assert headers is not None
        assert content is not None 
        assert url is not None 
        assert date_crawled is not None 
        assert content_type is not None 

        status = "Processed"

        # Fix for a seg-fault
        if "nasa.gov" in url:
            return False

        # Sort out the domain
        domain_identifier = None 
        logging.info("Retrieving domain...")
        domain_key = self.dc.get_Domain_key(url)
        while domain_identifier == None:
            domain_identifier = self.drw.get_domain(domain_key)

        domain = self._session.query(Domain).get(domain_identifier)
        assert domain is not None

        # Build database objects 
        path   = self.ac.get_path_fromurl(url)
        article = Article(path, date_crawled, crawl_id, domain, status)
        self._session.add(article)
        classified_by = self.swc.get_SoftwareVersion_fromstr(pysen.__VERSION__)
        assert classified_by is not None

        if content_type != 'text/html':
            logging.error("Unsupported content type: %s", str(content_type))
            article.status = "UnsupportedType"
            return False

        # Start the async transaction to get the plain text
        worker_req_thread = BoilerPipeWorker(content)
        worker_req_thread.start()

        # Whilst that's executing, parse the document 
        logging.info("Parsing HTML...")
        html = BeautifulSoup(content)

        if html is None or html.body is None:
            article.status = "NoContent"
            return False

        # Extract the dates 
        date_dict = pydate.get_dates(html)

        if len(date_dict) == 0:
            status = "NoDates"

        # Detect the language
        lang, lang_certainty = langid.classify(content)

        # Wait for the BoilerPipe thread to complete
        worker_req_thread.join()
        logging.debug(worker_req_thread.result)
        logging.debug(worker_req_thread.version)

        if worker_req_thread.result == None:
            article.status = "NoContent"
            return False

        # If the language isn't English, skip it
        if lang != "en":
            logging.info("language: %s with certainty %.2f - skipping...", lang, lang_certainty)
            article.status = "LanguageError" # Replace with something appropriate
            return False

        content = worker_req_thread.result.encode('ascii', 'ignore')

        # Headline extraction 
        h_counter = 6
        headline = None
        while h_counter > 0:
            tag = "h%d" % (h_counter,)
            found = False 
            for node in html.findAll(tag):
                if node.text in content:
                    headline = node.text 
                    found = True 
                    break 
            if found:
                break
            h_counter -= 1

        # Run keyword extraction 
        keywords = self.ex(content)
        kset     = KeywordSet(self.stop_list)
        nnp_sets_scored = set([])

        for word, freq, amnt in sorted(keywords):
            try:
                nnp_sets_scored.add((word, freq))
            except ValueError:
                break 

        nnp_adj = set([])
        nnp_set = set([])
        nnp_vector = []
        for sentence in sent_tokenize(content):
            text = nltk.word_tokenize(sentence)
            pos  = nltk.pos_tag(text)
            pos_groups = itertools.groupby(pos, lambda x: x[1])
            for k, g in pos_groups:
                if k != 'NNP':
                    continue
                nnp_list = [word for word, speech in g]
                nnp_buf = []
                for item in nnp_list:
                    nnp_set.add(item)
                    nnp_buf.append(item)
                    nnp_vector.append(item)
                for i, j in zip(nnp_buf[0:-1], nnp_buf[1:]):
                    nnp_adj.add((i, j))

        nnp_vector = filter(lambda x: x.lower() not in self.stop_list, nnp_vector)
        nnp_counter = Counter(nnp_vector)
        for word in nnp_set:
            score = nnp_counter[word]
            nnp_sets_scored.add((word, score))

        for item, score in sorted(nnp_sets_scored, key=lambda x: x[1], reverse=True):
            try: 
                if type(item) == types.ListType or type(item) == types.TupleType:
                    kset.add(' '.join(item))
                else:
                    kset.add(item)
            except ValueError:
                break 

        scored_nnp_adj = []
        for item1, item2 in nnp_adj:
            score = nnp_counter[item1] + nnp_counter[item2]
            scored_nnp_adj.append((item1, item2, score))

        nnp_adj = []
        for item1, item2, score in sorted(scored_nnp_adj, key=lambda x: x[1], reverse=True):
            if len(nnp_adj) < KEYWORD_LIMIT:
                nnp_adj.append((item1, item2))
            else:
                break

        # Generate list of all keywords
        keywords = set([])
        for keyword in kset:
            try:
                k = Keyword(keyword)
                keywords.add(k)
            except ValueError as ex:
                logging.error(ex)
                continue
        for item1, item2 in nnp_adj:
            try:
                k = Keyword(item1)
                keywords.add(k)
            except ValueError as ex:
                logging.error(ex)
            try:
                k = Keyword(item2)
                keywords.add(k)
            except ValueError as ex:
                logging.error(ex)

        # Resolve keyword identifiers
        keyword_resolution_worker = KeywordResolutionWorker(set([k.word for k in keywords]), self.redis_kw)
        keyword_resolution_worker.start()
            
        # Run sentiment analysis
        trace = []
        features = self.cls.classify(worker_req_thread.result, trace)
        (label, length, classified, pos_sentences, neg_sentences,
         pos_phrases, neg_phrases) = features[0:7]

        # Convert Pysen's model into database models
        try:
            doc = Document(article.id, label, length, pos_sentences, neg_sentences, pos_phrases, neg_phrases, headline)
        except ValueError as ex:
            logging.error(ex)
            logging.error("Skipping this document...")
            article.status = "ClassificationError"
            return False

        self._session.add(doc)
        extracted_phrases = set([])
        for sentence, score, phrase_trace in trace:
            sentence_type = "Unknown"
            for node in html.findAll(text=True):
                if sentence.text in node.strip():
                    sentence_type = node.parent.name.upper()
                    break

            if sentence_type not in ["H1", "H2", "H3", "H4", "H5", "H6", "P", "Unknown"]:
                sentence_type = "Other"

            label, average, prob, pos, neg, probs, _scores = score 

            s = Sentence(doc, label, average, prob, sentence_type)
            self._session.add(s)
            for phrase, prob, score, label in phrase_trace:
                p = Phrase(s, score, prob, label)
                self._session.add(p)
                extracted_phrases.add((phrase, p))

        # Wait for keyword resolution to finish
        keyword_resolution_worker.join()
        keyword_mapping = keyword_resolution_worker.out_keywords

        # Associate extracted keywords with phrases
        keyword_objects, short_keywords = kset.convert(keyword_mapping, self.kwc)
        for k in keyword_objects:
            self._session.merge(k)
        for p, p_obj in extracted_phrases:
            for k in keyword_objects:
                if k.word in p.get_text():
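                    # Association object linking the keyword to this phrase;
                    # presumably persisted via relationship cascade on commit
                    # rather than being added to the session explicitly.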
                    nk = KeywordIncidence(k, p_obj)

        # Save the keyword adjacency list
        for i, j in kset.convert_adj_tuples(nnp_adj, keyword_mapping, self.kwc):
            self._session.merge(i)
            self._session.merge(j)
            kwa = KeywordAdjacency(i, j, doc)
            self._session.add(kwa)

        # Build date objects
        for key in date_dict:
            rec  = date_dict[key]
            if "dates" not in rec:
                logging.error("OK: 'dates' is not in a pydate result record.")
                continue
            dlen = len(rec["dates"])
            if rec["text"] not in content:
                logging.debug("'%s' is not in %s", rec["text"], content)
                continue
            if dlen > 1:
                for date, day_first, year_first in rec["dates"]:
                    try:
                        dobj = AmbiguousDate(date, doc, day_first, year_first, rec["prep"], key)
                    except ValueError as ex:
                        logging.error(ex)
                        continue
                    self._session.add(dobj)
            elif dlen == 1:
                for date, day_first, year_first in rec["dates"]:
                    dobj = CertainDate(date, doc, key)
                    self._session.add(dobj)
            else:
                logging.error("'dates' in a pydate result set contains no records.")

        # Process links
        for link in html.findAll('a'):
            if not link.has_attr("href"):
                logging.debug("skipping %s: no href", link)
                continue

            process = True 
            for node in link.findAll(text=True):
                if node not in worker_req_thread.result:
                    process = False 
                    break 
            
            if not process:
                logging.debug("skipping %s because it's not in the body text", link)
                continue

            href, _, _ = link["href"].partition("#")
            if "http://" in href:
                try:
                    domain_id = None
                    domain_key = self.dc.get_Domain_key(href)
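                    # Poll the domain-resolution worker until it returns an id
                    # (get_domain presumably yields None while the lookup is pending).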
                    while domain_id is None:
                        domain_id = self.drw.get_domain(domain_key)

                    assert domain_id is not None
                    href_domain = self._session.query(Domain).get(domain_id)
                except ValueError as ex:
                    logging.error(ex)
                    logging.error("Skipping this link")
                    continue
                href_path   = self.ac.get_path_fromurl(href)
                lnk = AbsoluteLink(doc, href_domain, href_path)
                self._session.add(lnk)
                logging.debug("Adding: %s", lnk)
            else:
                href_path  = href 
                try:
                    lnk = RelativeLink(doc, href_path)
                except ValueError as ex:
                    logging.error(ex)
                    logging.error("Skipping link")
                    continue
                self._session.add(lnk)
                logging.debug("Adding: %s", lnk)

        # Construct software involvment records
        self_sir = SoftwareInvolvementRecord(self.swc.get_SoftwareVersion_fromstr(self.__VERSION__), "Processed", doc)
        date_sir = SoftwareInvolvementRecord(self.swc.get_SoftwareVersion_fromstr(pydate.__VERSION__), "Dated", doc)
        clas_sir = SoftwareInvolvementRecord(self.swc.get_SoftwareVersion_fromstr(pysen.__VERSION__), "Classified", doc)
        extr_sir = SoftwareInvolvementRecord(self.swc.get_SoftwareVersion_fromstr(worker_req_thread.version), "Extracted", doc)

        for sw in [self_sir, date_sir, clas_sir, extr_sir]:
            self._session.merge(sw, load=True)

        logging.debug("Domain: %s", domain)
        logging.debug("Path: %s", path)
        article.status = status

        # Commit to database, return True on success
        try:
            self._session.commit()
        except OperationalError as ex:
            logging.error(ex)
            self._session.rollback()
            return None

        return article.id

    def finalize(self):
        self._session.commit()
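
The proper-noun extraction above interleaves several concerns, so here is a minimal standalone sketch of just the NNP-adjacency step, assuming only nltk and the standard library; the helper name, sample text, and stop list are illustrative and not part of the original pipeline:

import itertools
from collections import Counter

import nltk  # requires the 'punkt' and 'averaged_perceptron_tagger' data


def extract_nnp_bigrams(content, stop_list):
    """Collect runs of consecutive proper nouns and score adjacent pairs."""
    nnp_adj = set()
    nnp_vector = []
    for sentence in nltk.sent_tokenize(content):
        pos = nltk.pos_tag(nltk.word_tokenize(sentence))
        for tag, group in itertools.groupby(pos, lambda pair: pair[1]):
            if tag != 'NNP':
                continue
            run = [word for word, _ in group]
            nnp_vector.extend(run)
            nnp_adj.update(zip(run[:-1], run[1:]))
    counts = Counter(w for w in nnp_vector if w.lower() not in stop_list)
    # Score each adjacent pair by the combined frequency of its members,
    # mirroring the scored_nnp_adj computation in the worker above.
    return sorted(((a, b, counts[a] + counts[b]) for a, b in nnp_adj),
                  key=lambda t: t[2], reverse=True)


# e.g. extract_nnp_bigrams("Barack Obama met Angela Merkel in Berlin.", {"met"})
# would typically yield pairs like ('Barack', 'Obama', 2) and ('Angela', 'Merkel', 2).
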
Exemplo n.º 30
0
    def verify_integrity(self, session: Session = NEW_SESSION):
        """
        Verifies the DagRun by checking for removed tasks or tasks that are not in the
        database yet. It will set state to removed or add the task if required.

        :param session: Sqlalchemy ORM Session
        :type session: Session
        """
        from airflow.settings import task_instance_mutation_hook

        dag = self.get_dag()
        tis = self.get_task_instances(session=session)

        # check for removed or restored tasks
        task_ids = set()
        for ti in tis:
            task_instance_mutation_hook(ti)
            task_ids.add(ti.task_id)
            task = None
            try:
                task = dag.get_task(ti.task_id)
            except AirflowException:
                if ti.state == State.REMOVED:
                    pass  # ti has already been removed, just ignore it
                elif self.state != State.RUNNING and not dag.partial:
                    self.log.warning(
                        "Failed to get task '%s' for dag '%s'. Marking it as removed.",
                        ti, dag)
                    Stats.incr(f"task_removed_from_dag.{dag.dag_id}", 1, 1)
                    ti.state = State.REMOVED

            should_restore_task = (task
                                   is not None) and ti.state == State.REMOVED
            if should_restore_task:
                self.log.info(
                    "Restoring task '%s' which was previously removed from DAG '%s'",
                    ti, dag)
                Stats.incr(f"task_restored_to_dag.{dag.dag_id}", 1, 1)
                ti.state = State.NONE
            session.merge(ti)

        def task_filter(task: "BaseOperator"):
            return task.task_id not in task_ids and (
                self.is_backfill or task.start_date <= self.execution_date)

        created_counts: Dict[str, int] = defaultdict(int)

        # 'is_noop' is set on the empty default hook in airflow.settings -- if it is not set, the hook has been overridden
        hook_is_noop = getattr(task_instance_mutation_hook, 'is_noop', False)

        if hook_is_noop:

            def create_ti_mapping(task: "BaseOperator"):
                created_counts[task.task_type] += 1
                return TI.insert_mapping(self.run_id, task)

        else:

            def create_ti(task: "BaseOperator") -> TI:
                ti = TI(task, run_id=self.run_id)
                task_instance_mutation_hook(ti)
                created_counts[ti.operator] += 1
                return ti

        # Create missing tasks
        tasks = list(filter(task_filter, dag.task_dict.values()))
        try:
            if hook_is_noop:
                session.bulk_insert_mappings(TI, map(create_ti_mapping, tasks))
            else:
                session.bulk_save_objects(map(create_ti, tasks))

            for task_type, count in created_counts.items():
                Stats.incr(f"task_instance_created-{task_type}", count)
            session.flush()
        except IntegrityError as err:
            self.log.info(str(err))
            self.log.info(
                'Hit IntegrityError while creating the TIs for %s - %s',
                dag.dag_id, self.run_id)
            self.log.info('Doing session rollback.')
            # TODO[HA]: We probably need to savepoint this so we can keep the transaction alive.
            session.rollback()
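
The hook_is_noop branch above exists so that, when the mutation hook is the untouched default, rows can be written from plain dict mappings instead of constructing ORM objects. A rough standalone sketch of that SQLAlchemy trade-off, using a hypothetical Item model rather than TaskInstance and SQLAlchemy 1.4+ style imports, might look like this:

from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Item(Base):
    """Hypothetical stand-in for TaskInstance."""
    __tablename__ = 'item'
    id = Column(Integer, primary_key=True)
    name = Column(String)


def insert_items(session, names, mutate_hook=None):
    if mutate_hook is None:
        # No per-object hook to call, so skip ORM object construction entirely
        # and hand over a list of plain dicts (the fast path).
        session.bulk_insert_mappings(Item, [{'name': n} for n in names])
    else:
        # The hook needs real objects, so build them, mutate, and bulk-save.
        objs = []
        for n in names:
            obj = Item(name=n)
            mutate_hook(obj)
            objs.append(obj)
        session.bulk_save_objects(objs)
    session.flush()


engine = create_engine('sqlite://')
Base.metadata.create_all(engine)
with Session(engine) as session:
    insert_items(session, ['a', 'b'])
    session.commit()
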
Exemplo n.º 31
0
    def merge(self, entity, load=True):
        self.begin()
        SaSession.merge(self, entity, load=load)
        self.commit()
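
For context, SQLAlchemy's Session.merge reconciles a detached object with the persistent instance that shares its primary key (loading it when load=True) and returns the managed copy; the override above simply wraps that call in its own begin/commit. A small self-contained sketch with a hypothetical User model:

from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class User(Base):
    """Hypothetical model used only for this illustration."""
    __tablename__ = 'user'
    id = Column(Integer, primary_key=True)
    name = Column(String)


engine = create_engine('sqlite://')
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(User(id=42, name='original'))
    session.commit()

# merge() copies the detached object's state onto the row with the same
# primary key and returns the instance that the session actually tracks.
with Session(engine) as session:
    managed = session.merge(User(id=42, name='updated'), load=True)
    session.commit()
    assert managed.name == 'updated'
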
Exemplo n.º 32
0
    def update_state(
        self,
        session: Session = None,
        execute_callbacks: bool = True
    ) -> Tuple[List[TI], Optional[callback_requests.DagCallbackRequest]]:
        """
        Determines the overall state of the DagRun based on the state
        of its TaskInstances.

        :param session: Sqlalchemy ORM Session
        :type session: Session
        :param execute_callbacks: Should dag callbacks (success/failure, SLA etc) be invoked
            directly (default: true) or recorded as a pending request in the ``callback`` property
        :type execute_callbacks: bool
        :return: Tuple containing tis that can be scheduled in the current loop & `callback` that
            needs to be executed
        """
        # Callback to execute in case of Task Failures
        callback: Optional[callback_requests.DagCallbackRequest] = None

        start_dttm = timezone.utcnow()
        self.last_scheduling_decision = start_dttm

        dag = self.get_dag()
        ready_tis: List[TI] = []
        tis = list(
            self.get_task_instances(session=session,
                                    state=State.task_states +
                                    (State.SHUTDOWN, )))
        self.log.debug("number of tis tasks for %s: %s task(s)", self,
                       len(tis))
        for ti in tis:
            ti.task = dag.get_task(ti.task_id)

        unfinished_tasks = [t for t in tis if t.state in State.unfinished()]
        finished_tasks = [
            t for t in tis
            if t.state in State.finished() + [State.UPSTREAM_FAILED]
        ]
        none_depends_on_past = all(not t.task.depends_on_past
                                   for t in unfinished_tasks)
        none_task_concurrency = all(t.task.task_concurrency is None
                                    for t in unfinished_tasks)
        if unfinished_tasks:
            scheduleable_tasks = [
                ut for ut in unfinished_tasks
                if ut.state in SCHEDULEABLE_STATES
            ]
            self.log.debug("number of scheduleable tasks for %s: %s task(s)",
                           self, len(scheduleable_tasks))
            ready_tis, changed_tis = self._get_ready_tis(
                scheduleable_tasks, finished_tasks, session)
            self.log.debug("ready tis length for %s: %s task(s)", self,
                           len(ready_tis))
            if none_depends_on_past and none_task_concurrency:
                # small speed up
                are_runnable_tasks = ready_tis or self._are_premature_tis(
                    unfinished_tasks, finished_tasks, session) or changed_tis

        duration = (timezone.utcnow() - start_dttm)
        Stats.timing("dagrun.dependency-check.{}".format(self.dag_id),
                     duration)

        leaf_task_ids = {t.task_id for t in dag.leaves}
        leaf_tis = [ti for ti in tis if ti.task_id in leaf_task_ids]

        # if all roots finished and at least one failed, the run failed
        if not unfinished_tasks and any(
                leaf_ti.state in {State.FAILED, State.UPSTREAM_FAILED}
                for leaf_ti in leaf_tis):
            self.log.error('Marking run %s failed', self)
            self.set_state(State.FAILED)
            if execute_callbacks:
                dag.handle_callback(self,
                                    success=False,
                                    reason='task_failure',
                                    session=session)
            else:
                callback = callback_requests.DagCallbackRequest(
                    full_filepath=dag.fileloc,
                    dag_id=self.dag_id,
                    execution_date=self.execution_date,
                    is_failure_callback=True,
                    msg='task_failure')

        # if all leafs succeeded and no unfinished tasks, the run succeeded
        elif not unfinished_tasks and all(
                leaf_ti.state in {State.SUCCESS, State.SKIPPED}
                for leaf_ti in leaf_tis):
            self.log.info('Marking run %s successful', self)
            self.set_state(State.SUCCESS)
            if execute_callbacks:
                dag.handle_callback(self,
                                    success=True,
                                    reason='success',
                                    session=session)
            else:
                callback = callback_requests.DagCallbackRequest(
                    full_filepath=dag.fileloc,
                    dag_id=self.dag_id,
                    execution_date=self.execution_date,
                    is_failure_callback=False,
                    msg='success')

        # if *all tasks* are deadlocked, the run failed
        elif (unfinished_tasks and none_depends_on_past
              and none_task_concurrency and not are_runnable_tasks):
            self.log.error('Deadlock; marking run %s failed', self)
            self.set_state(State.FAILED)
            if execute_callbacks:
                dag.handle_callback(self,
                                    success=False,
                                    reason='all_tasks_deadlocked',
                                    session=session)
            else:
                callback = callback_requests.DagCallbackRequest(
                    full_filepath=dag.fileloc,
                    dag_id=self.dag_id,
                    execution_date=self.execution_date,
                    is_failure_callback=True,
                    msg='all_tasks_deadlocked')

        # finally, if the roots aren't done, the dag is still running
        else:
            self.set_state(State.RUNNING)

        self._emit_duration_stats_for_finished_state()

        session.merge(self)

        return ready_tis, callback
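
The four branches above amount to a small decision table over the leaf-task states. Below is a framework-free sketch of that table, using plain strings instead of Airflow's State constants and ignoring the depends_on_past / task_concurrency guards on the deadlock branch:

def decide_run_state(leaf_states, has_unfinished, has_runnable):
    """Simplified mirror of the DagRun outcome logic above."""
    if not has_unfinished and any(
            s in ('failed', 'upstream_failed') for s in leaf_states):
        return 'failed', 'task_failure'
    if not has_unfinished and all(
            s in ('success', 'skipped') for s in leaf_states):
        return 'success', 'success'
    if has_unfinished and not has_runnable:
        # Unfinished work but nothing ready, premature, or changed: deadlock.
        return 'failed', 'all_tasks_deadlocked'
    return 'running', None


assert decide_run_state(['success', 'skipped'], False, False) == ('success', 'success')
assert decide_run_state(['failed', 'success'], False, False) == ('failed', 'task_failure')
assert decide_run_state(['success'], True, False) == ('failed', 'all_tasks_deadlocked')
assert decide_run_state(['success'], True, True) == ('running', None)
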
Exemplo n.º 33
0
    def update_state(self, session: Session = None) -> List[TI]:
        """
        Determines the overall state of the DagRun based on the state
        of its TaskInstances.

        :param session: Sqlalchemy ORM Session
        :type session: Session
        :return: ready_tis: the tis that can be scheduled in the current loop
        :rtype ready_tis: list[airflow.models.TaskInstance]
        """

        dag = self.get_dag()
        ready_tis: List[TI] = []
        tis = list(self.get_task_instances(session=session, state=State.task_states + (State.SHUTDOWN,)))
        self.log.debug("number of tis tasks for %s: %s task(s)", self, len(tis))
        for ti in tis:
            ti.task = dag.get_task(ti.task_id)

        start_dttm = timezone.utcnow()
        unfinished_tasks = [t for t in tis if t.state in State.unfinished()]
        finished_tasks = [t for t in tis if t.state in State.finished() + [State.UPSTREAM_FAILED]]
        none_depends_on_past = all(not t.task.depends_on_past for t in unfinished_tasks)
        none_task_concurrency = all(t.task.task_concurrency is None
                                    for t in unfinished_tasks)
        if unfinished_tasks:
            scheduleable_tasks = [ut for ut in unfinished_tasks if ut.state in SCHEDULEABLE_STATES]
            self.log.debug(
                "number of scheduleable tasks for %s: %s task(s)",
                self, len(scheduleable_tasks))
            ready_tis, changed_tis = self._get_ready_tis(scheduleable_tasks, finished_tasks, session)
            self.log.debug("ready tis length for %s: %s task(s)", self, len(ready_tis))
            if none_depends_on_past and none_task_concurrency:
                # small speed up
                are_runnable_tasks = ready_tis or self._are_premature_tis(
                    unfinished_tasks, finished_tasks, session) or changed_tis

        duration = (timezone.utcnow() - start_dttm)
        Stats.timing("dagrun.dependency-check.{}".format(self.dag_id), duration)

        leaf_task_ids = {t.task_id for t in dag.leaves}
        leaf_tis = [ti for ti in tis if ti.task_id in leaf_task_ids]

        # if all roots finished and at least one failed, the run failed
        if not unfinished_tasks and any(
            leaf_ti.state in {State.FAILED, State.UPSTREAM_FAILED} for leaf_ti in leaf_tis
        ):
            self.log.error('Marking run %s failed', self)
            self.set_state(State.FAILED)
            dag.handle_callback(self, success=False, reason='task_failure', session=session)

        # if all leafs succeeded and no unfinished tasks, the run succeeded
        elif not unfinished_tasks and all(
            leaf_ti.state in {State.SUCCESS, State.SKIPPED} for leaf_ti in leaf_tis
        ):
            self.log.info('Marking run %s successful', self)
            self.set_state(State.SUCCESS)
            dag.handle_callback(self, success=True, reason='success', session=session)

        # if *all tasks* are deadlocked, the run failed
        elif (unfinished_tasks and none_depends_on_past and
              none_task_concurrency and not are_runnable_tasks):
            self.log.error('Deadlock; marking run %s failed', self)
            self.set_state(State.FAILED)
            dag.handle_callback(self, success=False, reason='all_tasks_deadlocked',
                                session=session)

        # finally, if the roots aren't done, the dag is still running
        else:
            self.set_state(State.RUNNING)

        self._emit_duration_stats_for_finished_state()

        # todo: determine we want to use with_for_update to make sure to lock the run
        session.merge(self)
        session.commit()

        return ready_tis
Exemplo n.º 34
0
class InboxSession(object):
    """ Inbox custom ORM (with SQLAlchemy compatible API).

    Parameters
    ----------
    engine : <sqlalchemy.engine.Engine>
        A configured database engine to use for this session
    versioned : bool
        Do you want to enable the transaction log?
    ignore_soft_deletes : bool
        Whether or not to ignore soft-deleted objects in query results.
    namespace_id : int
        Namespace to limit query results with.
    """
    def __init__(self, engine, versioned=True, ignore_soft_deletes=True,
                 namespace_id=None):
        # TODO: support limiting on namespaces
        assert engine, "Must set the database engine"

        args = dict(bind=engine, autoflush=True, autocommit=False)
        self.ignore_soft_deletes = ignore_soft_deletes
        if ignore_soft_deletes:
            args['query_cls'] = InboxQuery
        self._session = Session(**args)

        if versioned:
            from inbox.models.transaction import create_revisions

            @event.listens_for(self._session, 'after_flush')
            def after_flush(session, flush_context):
                """
                Hook to log revision snapshots. Must be post-flush in order to
                grab object IDs on new objects.
                """
                create_revisions(session)

    def query(self, *args, **kwargs):
        q = self._session.query(*args, **kwargs)
        if self.ignore_soft_deletes:
            return q.options(IgnoreSoftDeletesOption())
        else:
            return q

    def add(self, instance):
        if not self.ignore_soft_deletes or not instance.is_deleted:
            self._session.add(instance)
        else:
            raise Exception("Why are you adding a deleted object?")

    def add_all(self, instances):
        if not self.ignore_soft_deletes or \
                not any(i.is_deleted for i in instances):
            self._session.add_all(instances)
        else:
            raise Exception("Why are you adding a deleted object?")

    def delete(self, instance):
        if self.ignore_soft_deletes:
            instance.mark_deleted()
            # just to make sure
            self._session.add(instance)
        else:
            self._session.delete(instance)

    def begin(self):
        self._session.begin()

    def commit(self):
        self._session.commit()

    def rollback(self):
        self._session.rollback()

    def flush(self):
        self._session.flush()

    def close(self):
        self._session.close()

    def expunge(self, obj):
        self._session.expunge(obj)

    def merge(self, obj):
        return self._session.merge(obj)

    @property
    def no_autoflush(self):
        return self._session.no_autoflush
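
A brief usage sketch of the wrapper, under the assumption that `engine` is a configured SQLAlchemy engine and `Contact` is some mapped model carrying the `is_deleted` flag that InboxQuery understands (both hypothetical here):

# Hypothetical engine and model; neither is defined in this snippet.
session = InboxSession(engine, versioned=True, ignore_soft_deletes=True)

contact = Contact(name="Alice", is_deleted=False)
session.add(contact)                 # refuses objects already soft-deleted
session.commit()                     # after_flush hook records a revision

rows = session.query(Contact).all()  # InboxQuery filters soft-deleted rows

session.delete(contact)              # marks is_deleted instead of a real DELETE
session.commit()
session.close()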