Example #1
class TaskFail(Base):
    """TaskFail tracks the failed run durations of each task instance."""

    __tablename__ = "task_fail"

    id = Column(Integer, primary_key=True)
    task_id = Column(StringID(), nullable=False)
    dag_id = Column(StringID(), nullable=False)
    run_id = Column(StringID(), nullable=False)
    map_index = Column(Integer, nullable=False)
    start_date = Column(UtcDateTime)
    end_date = Column(UtcDateTime)
    duration = Column(Integer)

    __table_args__ = (
        ForeignKeyConstraint(
            [dag_id, task_id, run_id, map_index],
            [
                "task_instance.dag_id",
                "task_instance.task_id",
                "task_instance.run_id",
                "task_instance.map_index",
            ],
            name='task_fail_ti_fkey',
            ondelete="CASCADE",
        ),
    )

    # We don't need a DB level FK here, as we already have that to TI (which has one to DR) but by defining
    # the relationship we can more easily find the execution date for these rows
    dag_run = relationship(
        "DagRun",
        primaryjoin="""and_(
            TaskFail.dag_id == foreign(DagRun.dag_id),
            TaskFail.run_id == foreign(DagRun.run_id),
        )""",
        viewonly=True,
    )

    def __init__(self, ti):
        self.dag_id = ti.dag_id
        self.task_id = ti.task_id
        self.run_id = ti.run_id
        self.map_index = ti.map_index
        self.start_date = ti.start_date
        self.end_date = ti.end_date
        if self.end_date and self.start_date:
            self.duration = int((self.end_date - self.start_date).total_seconds())
        else:
            self.duration = None

    def __repr__(self):
        prefix = f"<{self.__class__.__name__}: {self.dag_id}.{self.task_id} {self.run_id}"
        if self.map_index != -1:
            prefix += f" map_index={self.map_index}"
        return prefix + '>'
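The constructor above derives duration as whole wall-clock seconds between end_date and start_date, falling back to None when either bound is missing. A standalone sketch of that computation, using plain datetimes rather than a real TaskInstance:

from datetime import datetime, timedelta, timezone

start = datetime(2023, 1, 1, 12, 0, tzinfo=timezone.utc)
end = start + timedelta(minutes=2, seconds=30)

# Same logic as TaskFail.__init__: int() truncates to whole seconds.
duration = int((end - start).total_seconds()) if start and end else None
print(duration)  # 150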
Example #2
def upgrade():
    """
    Add ``map_index`` column to TaskInstance to identify task-mapping,
    and a ``task_map`` table to track mapping values from XCom.
    """
    # We need to first remove constraints on task_reschedule since they depend on task_instance.
    with op.batch_alter_table("task_reschedule") as batch_op:
        batch_op.drop_constraint("task_reschedule_ti_fkey", "foreignkey")
        batch_op.drop_index("idx_task_reschedule_dag_task_run")

    # Change task_instance's primary key.
    with op.batch_alter_table("task_instance") as batch_op:
        # I think we always use this name for TaskInstance after 7b2661a43ba3?
        batch_op.drop_constraint("task_instance_pkey", type_="primary")
        batch_op.add_column(Column("map_index", Integer, nullable=False, server_default=text("-1")))
        batch_op.create_primary_key("task_instance_pkey", ["dag_id", "task_id", "run_id", "map_index"])

    # Re-create task_reschedule's constraints.
    with op.batch_alter_table("task_reschedule") as batch_op:
        batch_op.add_column(Column("map_index", Integer, nullable=False, server_default=text("-1")))
        batch_op.create_foreign_key(
            "task_reschedule_ti_fkey",
            "task_instance",
            ["dag_id", "task_id", "run_id", "map_index"],
            ["dag_id", "task_id", "run_id", "map_index"],
            ondelete="CASCADE",
        )
        batch_op.create_index(
            "idx_task_reschedule_dag_task_run",
            ["dag_id", "task_id", "run_id", "map_index"],
            unique=False,
        )

    # Create task_map.
    op.create_table(
        "task_map",
        Column("dag_id", StringID(), primary_key=True),
        Column("task_id", StringID(), primary_key=True),
        Column("run_id", StringID(), primary_key=True),
        Column("map_index", Integer, primary_key=True),
        Column("length", Integer, nullable=False),
        Column("keys", ExtendedJSON, nullable=True),
        CheckConstraint("length >= 0", name="task_map_length_not_negative"),
        ForeignKeyConstraint(
            ["dag_id", "task_id", "run_id", "map_index"],
            [
                "task_instance.dag_id",
                "task_instance.task_id",
                "task_instance.run_id",
                "task_instance.map_index",
            ],
            name="task_map_task_instance_fkey",
            ondelete="CASCADE",
        ),
    )
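The ordering above matters: task_reschedule's foreign key targets task_instance's primary key, so it has to be dropped before the key changes and re-created afterwards. A downgrade would mirror the same dance in reverse; the sketch below illustrates that shape and is not the actual Airflow downgrade:

def downgrade():
    """Illustrative reversal only; the real migration also handles data cleanup."""
    op.drop_table("task_map")
    with op.batch_alter_table("task_reschedule") as batch_op:
        batch_op.drop_constraint("task_reschedule_ti_fkey", "foreignkey")
        batch_op.drop_index("idx_task_reschedule_dag_task_run")
        batch_op.drop_column("map_index")
    with op.batch_alter_table("task_instance") as batch_op:
        batch_op.drop_constraint("task_instance_pkey", type_="primary")
        batch_op.drop_column("map_index")
        batch_op.create_primary_key("task_instance_pkey", ["dag_id", "task_id", "run_id"])
    with op.batch_alter_table("task_reschedule") as batch_op:
        batch_op.create_foreign_key(
            "task_reschedule_ti_fkey",
            "task_instance",
            ["dag_id", "task_id", "run_id"],
            ["dag_id", "task_id", "run_id"],
            ondelete="CASCADE",
        )
        batch_op.create_index(
            "idx_task_reschedule_dag_task_run",
            ["dag_id", "task_id", "run_id"],
            unique=False,
        )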
Example #3
def upgrade():
    """Add TaskMap and map_index on TaskInstance."""
    # We need to first remove constraints on task_reschedule since they depend on task_instance.
    with op.batch_alter_table("task_reschedule") as batch_op:
        batch_op.drop_constraint("task_reschedule_ti_fkey", "foreignkey")
        batch_op.drop_index("idx_task_reschedule_dag_task_run")

    # Change task_instance's primary key.
    with op.batch_alter_table("task_instance") as batch_op:
        # I think we always use this name for TaskInstance after 7b2661a43ba3?
        batch_op.drop_constraint("task_instance_pkey", type_="primary")
        batch_op.add_column(
            Column("map_index", Integer, nullable=False, server_default=text("-1")))
        batch_op.create_primary_key(
            "task_instance_pkey", ["dag_id", "task_id", "run_id", "map_index"])

    # Re-create task_reschedule's constraints.
    with op.batch_alter_table("task_reschedule") as batch_op:
        batch_op.add_column(
            Column("map_index", Integer, nullable=False, server_default=text("-1")))
        batch_op.create_foreign_key(
            "task_reschedule_ti_fkey",
            "task_instance",
            ["dag_id", "task_id", "run_id", "map_index"],
            ["dag_id", "task_id", "run_id", "map_index"],
            ondelete="CASCADE",
        )
        batch_op.create_index(
            "idx_task_reschedule_dag_task_run",
            ["dag_id", "task_id", "run_id", "map_index"],
            unique=False,
        )

    # Create task_map.
    op.create_table(
        "task_map",
        Column("dag_id", StringID(), primary_key=True),
        Column("task_id", StringID(), primary_key=True),
        Column("run_id", StringID(), primary_key=True),
        Column("map_index", Integer, primary_key=True),
        Column("length", Integer, nullable=False),
        Column("keys", ExtendedJSON, nullable=True),
        ForeignKeyConstraint(
            ["dag_id", "task_id", "run_id", "map_index"],
            [
                "task_instance.dag_id",
                "task_instance.task_id",
                "task_instance.run_id",
                "task_instance.map_index",
            ],
            name="task_map_task_instance_fkey",
            ondelete="CASCADE",
        ),
    )
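Both revisions add map_index as NOT NULL, which only works on a populated table when the database can backfill a value. That is what server_default=text("-1") provides: it is emitted as a DEFAULT clause in the DDL, whereas SQLAlchemy's default only applies at ORM/Core insert time and leaves pre-existing rows with no value. A minimal sketch of the distinction:

from sqlalchemy import Column, Integer, text

# Python-side default: used by SQLAlchemy at INSERT time only; the ALTER TABLE
# carries no DEFAULT clause, so adding the column to a non-empty table violates
# NOT NULL on most backends.
python_side = Column("map_index", Integer, nullable=False, default=-1)

# Server-side default: rendered as DEFAULT '-1' in the DDL, so the database
# itself backfills rows that predate the column.
server_side = Column("map_index", Integer, nullable=False, server_default=text("-1"))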
Example #4
class Log(Base):
    """Used to actively log events to the database"""

    __tablename__ = "log"

    id = Column(Integer, primary_key=True)
    dttm = Column(UtcDateTime)
    dag_id = Column(StringID())
    task_id = Column(StringID())
    map_index = Column(Integer)
    event = Column(String(30))
    execution_date = Column(UtcDateTime)
    owner = Column(String(500))
    extra = Column(Text)

    __table_args__ = (
        Index('idx_log_dag', dag_id),
        Index('idx_log_event', event),
    )

    def __init__(self,
                 event,
                 task_instance=None,
                 owner=None,
                 extra=None,
                 **kwargs):
        self.dttm = timezone.utcnow()
        self.event = event
        self.extra = extra

        task_owner = None

        if task_instance:
            self.dag_id = task_instance.dag_id
            self.task_id = task_instance.task_id
            self.execution_date = task_instance.execution_date
            self.map_index = task_instance.map_index
            if task_instance.task:
                task_owner = task_instance.task.owner

        if 'task_id' in kwargs:
            self.task_id = kwargs['task_id']
        if 'dag_id' in kwargs:
            self.dag_id = kwargs['dag_id']
        if kwargs.get('execution_date'):
            self.execution_date = kwargs['execution_date']
        if 'map_index' in kwargs:
            self.map_index = kwargs['map_index']

        self.owner = owner or task_owner
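In the constructor above, the kwargs checks run after the task_instance block, so explicitly passed identifiers win over values copied from the task instance. A hedged sketch with a stand-in object, not a real TaskInstance, assuming the Log class above is importable:

from types import SimpleNamespace

# Minimal stand-in carrying just the attributes Log.__init__ reads.
ti = SimpleNamespace(dag_id="example_dag", task_id="extract",
                     execution_date=None, map_index=-1, task=None)

log = Log(event="cli_task_run", task_instance=ti, dag_id="override_dag")
assert log.dag_id == "override_dag"  # explicit kwarg wins over ti.dag_id
assert log.task_id == "extract"      # still taken from the stand-in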
Example #5
class TaskFail(Base):
    """TaskFail tracks the failed run durations of each task instance."""

    __tablename__ = "task_fail"

    id = Column(Integer, primary_key=True)
    task_id = Column(StringID(), nullable=False)
    dag_id = Column(StringID(), nullable=False)
    run_id = Column(StringID(), nullable=False)
    map_index = Column(Integer, nullable=False)
    start_date = Column(UtcDateTime)
    end_date = Column(UtcDateTime)
    duration = Column(Integer)

    __table_args__ = (ForeignKeyConstraint(
        [dag_id, task_id, run_id, map_index],
        [
            "task_instance.dag_id",
            "task_instance.task_id",
            "task_instance.run_id",
            "task_instance.map_index",
        ],
        name='task_fail_ti_fkey',
        ondelete="CASCADE",
    ), )

    def __init__(self, task, run_id, start_date, end_date, map_index):
        self.dag_id = task.dag_id
        self.task_id = task.task_id
        self.run_id = run_id
        self.map_index = map_index
        self.start_date = start_date
        self.end_date = end_date
        if self.end_date and self.start_date:
            self.duration = int(
                (self.end_date - self.start_date).total_seconds())
        else:
            self.duration = None

    def __repr__(self):
        prefix = f"<{self.__class__.__name__}: {self.dag_id}.{self.task_id} {self.run_id}"
        if self.map_index != -1:
            prefix += f" map_index={self.map_index}"
        return prefix + '>'
Example #6
class DagWarning(Base):
    """
    A table to store DAG warnings.

    DAG warnings are problems that don't rise to the level of failing the DAG parse
    but which users should nonetheless be warned about. These warnings are recorded
    when parsing the DAG and displayed on the webserver in a flash message.
    """

    dag_id = Column(StringID(), primary_key=True)
    warning_type = Column(String(50), primary_key=True)
    message = Column(Text, nullable=False)
    timestamp = Column(UtcDateTime, nullable=False, default=timezone.utcnow)

    __tablename__ = "dag_warning"
    __table_args__ = (ForeignKeyConstraint(
        ('dag_id', ),
        ['dag.dag_id'],
        name='dcw_dag_id_fkey',
        ondelete='CASCADE',
    ), )

    def __init__(self, dag_id, error_type, message, **kwargs):
        super().__init__(**kwargs)
        self.dag_id = dag_id
        self.warning_type = DagWarningType(
            error_type).value  # make sure valid type
        self.message = message

    def __eq__(self, other):
        return self.dag_id == other.dag_id and self.warning_type == other.warning_type

    def __hash__(self):
        return hash((self.dag_id, self.warning_type))

    @classmethod
    @provide_session
    def purge_inactive_dag_warnings(cls, session=NEW_SESSION):
        """
        Purge DagWarning records for inactive DAGs.

        :return: None
        """
        from airflow.models.dag import DagModel

        if session.get_bind().dialect.name == 'sqlite':
            dag_ids = [
                dag_id for (dag_id, ) in session.query(DagModel.dag_id).filter(
                    DagModel.is_active == false())
            ]
            session.query(cls).filter(
                cls.dag_id.in_(dag_ids)).delete(synchronize_session=False)
        else:
            session.query(cls).filter(cls.dag_id == DagModel.dag_id,
                                      DagModel.is_active == false()).delete(
                                          synchronize_session=False)
        session.commit()
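The constructor validates error_type by round-tripping it through the DagWarningType enum, so an unknown type raises ValueError before anything is persisted. A hedged sketch; "non-existent pool" matches Airflow's known warning type, but treat the identifiers as illustrative:

warning = DagWarning("example_dag", "non-existent pool",
                     "pool 'missing' does not exist")

try:
    DagWarning("example_dag", "bogus type", "never stored")
except ValueError:
    pass  # invalid warning types fail fast, before any DB access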
Example #7
class DatasetDagRunQueue(Base):
    """Model for storing dataset events that need processing."""

    dataset_id = Column(Integer, primary_key=True, nullable=False)
    target_dag_id = Column(StringID(), primary_key=True, nullable=False)
    created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)

    __tablename__ = "dataset_dag_run_queue"
    __table_args__ = (
        PrimaryKeyConstraint(dataset_id,
                             target_dag_id,
                             name="datasetdagrunqueue_pkey",
                             mssql_clustered=True),
        ForeignKeyConstraint(
            (dataset_id, ),
            ["dataset.id"],
            name='ddrq_dataset_fkey',
            ondelete="CASCADE",
        ),
        ForeignKeyConstraint(
            (target_dag_id, ),
            ["dag.dag_id"],
            name='ddrq_dag_fkey',
            ondelete="CASCADE",
        ),
    )

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return self.dataset_id == other.dataset_id and self.target_dag_id == other.target_dag_id
        else:
            return NotImplemented

    def __hash__(self):
        return hash(self.__mapper__.primary_key)

    def __repr__(self):
        args = []
        for attr in [x.name for x in self.__mapper__.primary_key]:
            args.append(f"{attr}={getattr(self, attr)!r}")
        return f"{self.__class__.__name__}({', '.join(args)})"
Example #8
class RenderedTaskInstanceFields(Base):
    """Save Rendered Template Fields"""

    __tablename__ = "rendered_task_instance_fields"

    dag_id = Column(StringID(), primary_key=True)
    task_id = Column(StringID(), primary_key=True)
    run_id = Column(StringID(), primary_key=True)
    map_index = Column(Integer, primary_key=True, server_default='-1')
    rendered_fields = Column(sqlalchemy_jsonfield.JSONField(json=json),
                             nullable=False)
    k8s_pod_yaml = Column(sqlalchemy_jsonfield.JSONField(json=json),
                          nullable=True)

    __table_args__ = (ForeignKeyConstraint(
        [dag_id, task_id, run_id, map_index],
        [
            "task_instance.dag_id",
            "task_instance.task_id",
            "task_instance.run_id",
            "task_instance.map_index",
        ],
        name='rtif_ti_fkey',
        ondelete="CASCADE",
    ), )
    task_instance = relationship(
        "TaskInstance",
        lazy='joined',
        back_populates="rendered_task_instance_fields",
    )

    # We don't need a DB level FK here, as we already have that to TI (which has one to DR) but by defining
    # the relationship we can more easily find the execution date for these rows
    dag_run = relationship(
        "DagRun",
        primaryjoin="""and_(
            RenderedTaskInstanceFields.dag_id == foreign(DagRun.dag_id),
            RenderedTaskInstanceFields.run_id == foreign(DagRun.run_id),
        )""",
        viewonly=True,
    )

    execution_date = association_proxy("dag_run", "execution_date")

    def __init__(self, ti: TaskInstance, render_templates=True):
        self.dag_id = ti.dag_id
        self.task_id = ti.task_id
        self.run_id = ti.run_id
        self.map_index = ti.map_index
        self.ti = ti
        if render_templates:
            ti.render_templates()
        self.task = ti.task
        if os.environ.get("AIRFLOW_IS_K8S_EXECUTOR_POD", None):
            self.k8s_pod_yaml = ti.render_k8s_pod_yaml()
        self.rendered_fields = {
            field: serialize_template_field(getattr(self.task, field))
            for field in self.task.template_fields
        }

        self._redact()

    def __repr__(self):
        prefix = f"<{self.__class__.__name__}: {self.dag_id}.{self.task_id} {self.run_id}"
        if self.map_index != -1:
            prefix += f" map_index={self.map_index}"
        return prefix + '>'

    def _redact(self):
        from airflow.utils.log.secrets_masker import redact

        if self.k8s_pod_yaml:
            self.k8s_pod_yaml = redact(self.k8s_pod_yaml)

        for field, rendered in self.rendered_fields.items():
            self.rendered_fields[field] = redact(rendered, field)

    @classmethod
    @provide_session
    def get_templated_fields(cls,
                             ti: TaskInstance,
                             session: Session = NEW_SESSION) -> Optional[dict]:
        """
        Get templated field for a TaskInstance from the RenderedTaskInstanceFields
        table.

        :param ti: Task Instance
        :param session: SqlAlchemy Session
        :return: Rendered Templated TI field
        """
        result = (session.query(cls.rendered_fields).filter(
            cls.dag_id == ti.dag_id,
            cls.task_id == ti.task_id,
            cls.run_id == ti.run_id,
            cls.map_index == ti.map_index,
        ).one_or_none())

        if result:
            rendered_fields = result.rendered_fields
            return rendered_fields
        else:
            return None

    @classmethod
    @provide_session
    def get_k8s_pod_yaml(cls,
                         ti: TaskInstance,
                         session: Session = NEW_SESSION) -> Optional[dict]:
        """
        Get rendered Kubernetes Pod Yaml for a TaskInstance from the RenderedTaskInstanceFields
        table.

        :param ti: Task Instance
        :param session: SqlAlchemy Session
        :return: Kubernetes Pod Yaml
        """
        result = (session.query(cls.k8s_pod_yaml).filter(
            cls.dag_id == ti.dag_id,
            cls.task_id == ti.task_id,
            cls.run_id == ti.run_id,
            cls.map_index == ti.map_index,
        ).one_or_none())
        return result.k8s_pod_yaml if result else None

    @provide_session
    def write(self, session: Session = None):
        """Write instance to database

        :param session: SqlAlchemy Session
        """
        session.merge(self)

    @classmethod
    @provide_session
    def delete_old_records(
        cls,
        task_id: str,
        dag_id: str,
        num_to_keep=conf.getint("core",
                                "max_num_rendered_ti_fields_per_task",
                                fallback=0),
        session: Session = None,
    ):
        """
        Keep only the last X (num_to_keep) records for a task by deleting the rest.

        In the case of data for a mapped task either all of the rows or none of the rows will be deleted, so
        we don't end up with partial data for a set of mapped Task Instances left in the database.

        :param task_id: Task ID
        :param dag_id: Dag ID
        :param num_to_keep: Number of Records to keep
        :param session: SqlAlchemy Session
        """
        from airflow.models.dagrun import DagRun

        if num_to_keep <= 0:
            return

        tis_to_keep_query = (session.query(
            cls.dag_id, cls.task_id, cls.run_id).filter(
                cls.dag_id == dag_id,
                cls.task_id == task_id).join(cls.dag_run).distinct().order_by(
                    DagRun.execution_date.desc()).limit(num_to_keep))

        if session.bind.dialect.name in ["postgresql", "sqlite"]:
            # Fetch Top X records given dag_id & task_id ordered by Execution Date
            subq1 = tis_to_keep_query.subquery()
            excluded = session.query(subq1.c.dag_id, subq1.c.task_id,
                                     subq1.c.run_id)
            session.query(cls).filter(
                cls.dag_id == dag_id,
                cls.task_id == task_id,
                tuple_(cls.dag_id, cls.task_id, cls.run_id).notin_(excluded),
            ).delete(synchronize_session=False)
        elif session.bind.dialect.name in ["mysql"]:
            cls._remove_old_rendered_ti_fields_mysql(dag_id, session, task_id,
                                                     tis_to_keep_query)
        else:
            # Fetch Top X records given dag_id & task_id ordered by Execution Date
            tis_to_keep = tis_to_keep_query.all()

            filter_tis = [
                not_(
                    and_(
                        cls.dag_id == ti.dag_id,
                        cls.task_id == ti.task_id,
                        cls.run_id == ti.run_id,
                    )) for ti in tis_to_keep
            ]

            session.query(cls).filter(
                and_(*filter_tis)).delete(synchronize_session=False)

        session.flush()

    @classmethod
    @retry_db_transaction
    def _remove_old_rendered_ti_fields_mysql(cls, dag_id, session, task_id,
                                             tis_to_keep_query):
        # Fetch Top X records given dag_id & task_id ordered by Execution Date
        subq1 = tis_to_keep_query.subquery('subq1')
        # Second Subquery
        # Workaround for MySQL Limitation (https://stackoverflow.com/a/19344141/5691525)
        # Limitation: This version of MySQL does not yet support
        # LIMIT & IN/ALL/ANY/SOME subquery
        subq2 = session.query(subq1.c.dag_id, subq1.c.task_id,
                              subq1.c.run_id).subquery('subq2')
        # This query might deadlock occasionally and it should be retried if fails (see decorator)
        session.query(cls).filter(
            cls.dag_id == dag_id,
            cls.task_id == task_id,
            tuple_(cls.dag_id, cls.task_id, cls.run_id).notin_(subq2),
        ).delete(synchronize_session=False)
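delete_old_records dispatches on the backend because MySQL cannot combine LIMIT with IN/ALL/ANY/SOME subqueries, hence the extra derived-table hop in _remove_old_rendered_ti_fields_mysql. A hedged call-site sketch with hypothetical identifiers:

# Trim rendered-field history for one task, keeping the 30 newest runs;
# @provide_session supplies the session automatically.
RenderedTaskInstanceFields.delete_old_records(
    task_id="extract",
    dag_id="example_dag",
    num_to_keep=30,
)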
Example #9
class DatasetEvent(Base):
    """
    A table to store dataset events.

    :param dataset_id: reference to DatasetModel record
    :param extra: JSON field for arbitrary extra info
    :param source_task_id: the task_id of the TI which updated the dataset
    :param source_dag_id: the dag_id of the TI which updated the dataset
    :param source_run_id: the run_id of the TI which updated the dataset
    :param source_map_index: the map_index of the TI which updated the dataset
    :param timestamp: the time the event was logged

    We use relationships instead of foreign keys so that dataset events are not deleted even
    if the foreign key object is.
    """

    id = Column(Integer, primary_key=True, autoincrement=True)
    dataset_id = Column(Integer, nullable=False)
    extra = Column(sqlalchemy_jsonfield.JSONField(json=json),
                   nullable=False,
                   default={})
    source_task_id = Column(StringID(), nullable=True)
    source_dag_id = Column(StringID(), nullable=True)
    source_run_id = Column(StringID(), nullable=True)
    source_map_index = Column(Integer,
                              nullable=True,
                              server_default=text("-1"))
    timestamp = Column(UtcDateTime, default=timezone.utcnow, nullable=False)

    __tablename__ = "dataset_event"
    __table_args__ = (
        Index('idx_dataset_id_timestamp', dataset_id, timestamp),
        {
            'sqlite_autoincrement': True
        },  # ensures PK values not reused
    )

    created_dagruns = relationship(
        "DagRun",
        secondary=association_table,
        backref="consumed_dataset_events",
    )

    source_task_instance = relationship(
        "TaskInstance",
        primaryjoin="""and_(
            DatasetEvent.source_dag_id == foreign(TaskInstance.dag_id),
            DatasetEvent.source_run_id == foreign(TaskInstance.run_id),
            DatasetEvent.source_task_id == foreign(TaskInstance.task_id),
            DatasetEvent.source_map_index == foreign(TaskInstance.map_index),
        )""",
        viewonly=True,
        lazy="select",
        uselist=False,
    )
    source_dag_run = relationship(
        "DagRun",
        primaryjoin="""and_(
            DatasetEvent.source_dag_id == foreign(DagRun.dag_id),
            DatasetEvent.source_run_id == foreign(DagRun.run_id),
        )""",
        viewonly=True,
        lazy="select",
        uselist=False,
    )
    dataset = relationship(
        DatasetModel,
        primaryjoin="DatasetEvent.dataset_id == foreign(DatasetModel.id)",
        viewonly=True,
        lazy="select",
        uselist=False,
    )

    @property
    def uri(self):
        return self.dataset.uri

    def __eq__(self, other) -> bool:
        if isinstance(other, self.__class__):
            return self.dataset_id == other.dataset_id and self.timestamp == other.timestamp
        else:
            return NotImplemented

    def __hash__(self) -> int:
        return hash((self.dataset_id, self.timestamp))

    def __repr__(self) -> str:
        args = []
        for attr in [
                'id',
                'dataset_id',
                'extra',
                'source_task_id',
                'source_dag_id',
                'source_run_id',
                'source_map_index',
        ]:
            args.append(f"{attr}={getattr(self, attr)!r}")
        return f"{self.__class__.__name__}({', '.join(args)})"