コード例 #1
0
def _get_ti(
    task: BaseOperator,
    exec_date_or_run_id: str,
    *,
    create_if_necessary: bool = False,
    session: Session = NEW_SESSION,
) -> TaskInstance:
    """Return the TaskInstance of ``task`` in the DAG run named by ``exec_date_or_run_id``.

    The DAG run is resolved through ``_get_dag_run`` and the task instance is
    then looked up on that run by task id.

    :param task: operator whose task instance is wanted
    :param exec_date_or_run_id: run_id or execution date identifying the DAG run
    :param create_if_necessary: forwarded to ``_get_dag_run``; when true, a
        missing task instance is constructed and linked to the run instead of
        raising
    :param session: session used for both the DAG run and the task instance lookup
    :raises TaskInstanceNotFound: the run has no instance of ``task`` and
        ``create_if_necessary`` is false
    :return: the task instance, refreshed from ``task``
    """
    dag_run = _get_dag_run(
        dag=task.dag,
        exec_date_or_run_id=exec_date_or_run_id,
        create_if_necessary=create_if_necessary,
        session=session,
    )

    # Reuse the caller's session so the lookup happens in the same session
    # (and transaction) as the DAG run query above.
    ti_or_none = dag_run.get_task_instance(task.task_id, session=session)
    if ti_or_none is None:
        if not create_if_necessary:
            raise TaskInstanceNotFound(
                f"TaskInstance for {task.dag.dag_id}, {task.task_id} with "
                f"run_id or execution_date of {exec_date_or_run_id!r} not found"
            )
        ti = TaskInstance(task, run_id=dag_run.run_id)
        ti.dag_run = dag_run
    else:
        ti = ti_or_none
    ti.refresh_from_task(task)
    return ti
コード例 #2
0
def generate_pod_yaml(args):
    """Generates yaml files for each task in the DAG. Used for testing output of KubernetesExecutor

    One file named ``<dag_id>_<task_id>_<date>.yml`` is written per task under
    ``<args.output_path>/airflow_yaml_output/``, and the directory is printed
    at the end.

    :param args: parsed CLI arguments; ``execution_date``, ``subdir``,
        ``dag_id`` and ``output_path`` are read
    """
    execution_date = args.execution_date
    dag = get_dag(subdir=args.subdir, dag_id=args.dag_id)
    yaml_output_path = args.output_path
    dr = DagRun(dag.dag_id, execution_date=execution_date)
    kube_config = KubeConfig()

    # None of these depend on the task being rendered, so build them once
    # instead of once per task.
    api_client = ApiClient()
    date_string = pod_generator.datetime_to_label_safe_datestring(execution_date)
    output_dir = yaml_output_path + "/airflow_yaml_output/"
    os.makedirs(os.path.dirname(output_dir), exist_ok=True)

    for task in dag.tasks:
        # The TI is never persisted; it only carries the task and the
        # in-memory DagRun needed to render the pod spec.
        ti = TaskInstance(task, None)
        ti.dag_run = dr
        pod = PodGenerator.construct_pod(
            dag_id=args.dag_id,
            task_id=ti.task_id,
            pod_id=create_pod_id(args.dag_id, ti.task_id),
            try_number=ti.try_number,
            kube_image=kube_config.kube_image,
            date=ti.execution_date,
            args=ti.command_as_list(),
            pod_override_object=PodGenerator.from_obj(ti.executor_config),
            scheduler_job_id="worker-config",
            namespace=kube_config.executor_namespace,
            base_worker_pod=PodGenerator.deserialize_model_file(kube_config.pod_template_file),
        )
        pod_mutation_hook(pod)
        yaml_file_name = f"{args.dag_id}_{ti.task_id}_{date_string}.yml"
        with open(output_dir + yaml_file_name, "w") as output:
            sanitized_pod = api_client.sanitize_for_serialization(pod)
            output.write(yaml.dump(sanitized_pod))
    print(f"YAML output can be found at {yaml_output_path}/airflow_yaml_output/")
コード例 #3
0
def _get_ti(
    task: BaseOperator,
    exec_date_or_run_id: str,
    map_index: int,
    *,
    create_if_necessary: bool = False,
    session: Session = NEW_SESSION,
) -> TaskInstance:
    """Fetch (or optionally build) the TaskInstance of ``task`` for one DAG run.

    :param task: operator whose task instance is wanted
    :param exec_date_or_run_id: run_id or execution date identifying the DAG run
    :param map_index: must be >= 0 for a mapped task and negative otherwise
    :param create_if_necessary: forwarded to ``_get_dag_run``; when true, a
        missing task instance is constructed and linked to the run instead of
        raising
    :param session: session used for the DAG run and task instance lookups
    :raises RuntimeError: ``map_index`` does not agree with ``task.is_mapped``
    :raises TaskInstanceNotFound: no matching task instance exists and
        ``create_if_necessary`` is false
    :return: the task instance, refreshed from ``task``
    """
    mapped = task.is_mapped
    if mapped and map_index < 0:
        raise RuntimeError("No map_index passed to mapped task")
    if not mapped and map_index >= 0:
        raise RuntimeError("map_index passed to non-mapped task")

    run = _get_dag_run(
        dag=task.dag,
        exec_date_or_run_id=exec_date_or_run_id,
        create_if_necessary=create_if_necessary,
        session=session,
    )

    existing = run.get_task_instance(task.task_id, map_index=map_index, session=session)
    if existing is not None:
        existing.refresh_from_task(task)
        return existing

    if not create_if_necessary:
        raise TaskInstanceNotFound(
            f"TaskInstance for {task.dag.dag_id}, {task.task_id}, map={map_index} with "
            f"run_id or execution_date of {exec_date_or_run_id!r} not found"
        )

    # TODO: Validate map_index is in range?
    created = TaskInstance(task, run_id=run.run_id, map_index=map_index)
    created.dag_run = run
    created.refresh_from_task(task)
    return created
コード例 #4
0
def _get_ti(task, exec_date_or_run_id, create_if_necessary=False, session=None):
    """Get the task instance through DagRun.run_id, if that fails, get the TI the old way

    :param task: operator whose task instance is wanted
    :param exec_date_or_run_id: run_id or execution date identifying the DAG run
    :param create_if_necessary: forwarded to ``_get_dag_run``; when true, a
        missing task instance is constructed and linked to the run instead of
        raising
    :param session: session forwarded to ``_get_dag_run``
    :raises TaskInstanceNotFound: the run has no instance of ``task`` and
        ``create_if_necessary`` is false
    :return: the task instance, refreshed from ``task``
    """
    dag_run = _get_dag_run(task.dag, exec_date_or_run_id, create_if_necessary, session)

    ti = dag_run.get_task_instance(task.task_id)
    if not ti:
        if not create_if_necessary:
            # Fail with a descriptive error rather than letting the
            # refresh_from_task call below blow up on None.
            raise TaskInstanceNotFound(
                f"TaskInstance for {task.dag.dag_id}, {task.task_id} with "
                f"run_id or execution_date of {exec_date_or_run_id!r} not found"
            )
        ti = TaskInstance(task, run_id=None)
        ti.dag_run = dag_run
    ti.refresh_from_task(task)
    return ti
コード例 #5
0
def create_context(task):
    """Build a small context dict for ``task`` around a throwaway DAG and DagRun.

    The task instance's ``xcom_push`` is replaced with a ``mock.Mock`` and the
    same instance is exposed under both ``"ti"`` and ``"task_instance"``.
    """
    dag = DAG(dag_id="dag")
    logical_date = timezone.datetime(
        2016, 1, 1, 1, 0, 0, tzinfo=pendulum.timezone("Europe/Amsterdam")
    )
    run = DagRun(dag_id=dag.dag_id, execution_date=logical_date)

    ti = TaskInstance(task=task)
    ti.dag_run = run
    ti.xcom_push = mock.Mock()

    context = {"dag": dag, "ts": logical_date.isoformat(), "task": task}
    context["ti"] = context["task_instance"] = ti
    return context
コード例 #6
0
def create_context(task):
    """Build a small context dict for ``task`` around a throwaway DAG and manual DagRun.

    The run id comes from ``DagRun.generate_run_id`` for a manual run at the
    fixed execution date. The task instance's ``xcom_push`` is replaced with a
    ``mock.Mock`` and the same instance is exposed under both ``"ti"`` and
    ``"task_instance"``.
    """
    dag = DAG(dag_id="dag")
    logical_date = timezone.datetime(
        2016, 1, 1, 1, 0, 0, tzinfo=pendulum.timezone("Europe/Amsterdam")
    )
    run = DagRun(
        dag_id=dag.dag_id,
        execution_date=logical_date,
        run_id=DagRun.generate_run_id(DagRunType.MANUAL, logical_date),
    )

    ti = TaskInstance(task=task)
    ti.dag_run = run
    ti.dag_id = dag.dag_id
    ti.xcom_push = mock.Mock()

    context = {"dag": dag, "run_id": run.run_id, "task": task}
    context["ti"] = context["task_instance"] = ti
    return context
コード例 #7
0
def dag_backfill(args, dag=None):
    """Run a backfill for a DAG, or print a dry run of its tasks.

    Configures logging, installs ``sigint_handler`` for SIGTERM and emits a
    ``PendingDeprecationWarning`` about ``--ignore-first-depends-on-past`` on
    every call. When only one of the start/end dates is given it is used for
    both.

    :param args: parsed CLI arguments; ``start_date``, ``end_date`` and
        ``ignore_first_depends_on_past`` may be rewritten in place
    :param dag: DAG to backfill; loaded with ``get_dag`` when falsy
    :raises AirflowException: neither date was given, or ``args.task_regex``
        matches no task
    """
    logging.basicConfig(level=settings.LOGGING_LEVEL, format=settings.SIMPLE_LOG_FORMAT)
    signal.signal(signal.SIGTERM, sigint_handler)

    import warnings

    warnings.warn(
        '--ignore-first-depends-on-past is deprecated as the value is always set to True',
        category=PendingDeprecationWarning,
    )

    # An explicit False is overridden; any other value is left untouched.
    if args.ignore_first_depends_on_past is False:
        args.ignore_first_depends_on_past = True

    if not (args.start_date or args.end_date):
        raise AirflowException("Provide a start_date and/or end_date")

    if not dag:
        dag = get_dag(args.subdir, args.dag_id)

    # If only one date is passed, using same as start and end
    args.end_date = args.end_date or args.start_date
    args.start_date = args.start_date or args.end_date

    if args.task_regex:
        dag = dag.partial_subset(
            task_ids_or_regex=args.task_regex,
            include_upstream=not args.ignore_dependencies,
        )
        if not dag.task_dict:
            raise AirflowException(
                f"There are no tasks that match '{args.task_regex}' regex. Nothing to run, exiting..."
            )

    run_conf = json.loads(args.conf) if args.conf else None

    if args.dry_run:
        print(f"Dry run of DAG {args.dag_id} on {args.start_date}")
        dry_dag_run = DagRun(dag.dag_id, execution_date=args.start_date)
        for task in dag.tasks:
            print(f"Task {task.task_id}")
            dry_ti = TaskInstance(task, run_id=None)
            dry_ti.dag_run = dry_dag_run
            dry_ti.dry_run()
        return

    if args.reset_dagruns:
        DAG.clear_dags(
            [dag],
            start_date=args.start_date,
            end_date=args.end_date,
            confirm_prompt=not args.yes,
            include_subdags=True,
            dag_run_state=DagRunState.QUEUED,
        )

    try:
        # Built inside the try so argument evaluation stays covered by the
        # ValueError handler, exactly like the call it feeds.
        run_kwargs = dict(
            start_date=args.start_date,
            end_date=args.end_date,
            mark_success=args.mark_success,
            local=args.local,
            donot_pickle=(args.donot_pickle or conf.getboolean('core', 'donot_pickle')),
            ignore_first_depends_on_past=args.ignore_first_depends_on_past,
            ignore_task_deps=args.ignore_dependencies,
            pool=args.pool,
            delay_on_limit_secs=args.delay_on_limit,
            verbose=args.verbose,
            conf=run_conf,
            rerun_failed_tasks=args.rerun_failed_tasks,
            run_backwards=args.run_backwards,
            continue_on_failures=args.continue_on_failures,
        )
        dag.run(**run_kwargs)
    except ValueError as err:
        # Report the problem and exit with a failing status.
        print(str(err))
        sys.exit(1)
コード例 #8
0
ファイル: test_airflow.py プロジェクト: hsheth2/datahub
def test_lineage_backend_capture_executions(mock_emit, inlets, outlets):
    """Check what the DataHub lineage backend emits when ``capture_executions`` is on.

    A two-task DAG is built, ``prepare_lineage``/``apply_lineage`` are run for
    the downstream task with a hand-made context, and the 17 emitted calls are
    compared, in order, against the expected aspect names and entity URNs.
    """
    DEFAULT_DATE = datetime.datetime(2020, 5, 17)
    mock_emitter = Mock()
    mock_emit.return_value = mock_emitter

    backend_env = {
        "AIRFLOW__LINEAGE__BACKEND": "datahub_provider.lineage.datahub.DatahubLineageBackend",
        "AIRFLOW__LINEAGE__DATAHUB_CONN_ID": datahub_rest_connection_config.conn_id,
        "AIRFLOW__LINEAGE__DATAHUB_KWARGS": json.dumps(
            {"graceful_exceptions": False, "capture_executions": True}
        ),
    }

    # Using autospec on xcom_pull and xcom_push methods fails on Python 3.6.
    with mock.patch.dict(os.environ, backend_env), mock.patch(
        "airflow.models.BaseOperator.xcom_pull"
    ), mock.patch("airflow.models.BaseOperator.xcom_push"), patch_airflow_connection(
        datahub_rest_connection_config
    ):
        func = mock.Mock()
        func.__name__ = "foo"

        dag = DAG(dag_id="test_lineage_is_sent_to_backend", start_date=DEFAULT_DATE)

        with dag:
            op1 = DummyOperator(task_id="task1_upstream", inlets=inlets, outlets=outlets)
            op2 = DummyOperator(task_id="task2", inlets=inlets, outlets=outlets)
            op1 >> op2

        # Airflow < 2.2 requires the execution_date parameter. Newer Airflow
        # versions do not require it, but will attempt to find the associated
        # run_id in the database if execution_date is provided. As such, we
        # must fake the run_id parameter for newer Airflow versions.
        if AIRFLOW_VERSION < packaging.version.parse("2.2.0"):
            ti = TaskInstance(task=op2, execution_date=DEFAULT_DATE)
            # Ignoring type here because DagRun state is just a string at Airflow 1
            dag_run = DagRun(state="success", run_id=f"scheduled_{DEFAULT_DATE}")  # type: ignore
        else:
            from airflow.utils.state import DagRunState

            ti = TaskInstance(task=op2, run_id=f"test_airflow-{DEFAULT_DATE}")
            dag_run = DagRun(state=DagRunState.SUCCESS, run_id=f"scheduled_{DEFAULT_DATE}")

        # Same wiring for both Airflow generations.
        ti.dag_run = dag_run
        ti.start_date = datetime.datetime.utcnow()
        ti.execution_date = DEFAULT_DATE

        ctx1 = {
            "dag": dag,
            "task": op2,
            "ti": ti,
            "dag_run": dag_run,
            "task_instance": ti,
            "execution_date": DEFAULT_DATE,
            "ts": "2021-04-08T00:54:25.771575+00:00",
        }

        prep = prepare_lineage(func)
        prep(op2, ctx1)
        post = apply_lineage(func)
        post(op2, ctx1)

        # Verify that the inlets and outlets are registered and recognized by Airflow correctly,
        # or that our lineage backend forces it to.
        assert len(op2.inlets) == 1
        assert len(op2.outlets) == 1
        assert all(isinstance(let, Dataset) for let in op2.inlets)
        assert all(isinstance(let, Dataset) for let in op2.outlets)

        # Check that the right things were emitted.
        assert mock_emitter.emit.call_count == 17
        # Running further checks based on python version because args only exists in python 3.7+
        if sys.version_info[:3] > (3, 7):
            flow_urn = "urn:li:dataFlow:(airflow,test_lineage_is_sent_to_backend,prod)"
            job_urn = "urn:li:dataJob:(urn:li:dataFlow:(airflow,test_lineage_is_sent_to_backend,prod),task2)"
            upstream_job_urn = "urn:li:dataJob:(urn:li:dataFlow:(airflow,test_lineage_is_sent_to_backend,prod),task1_upstream)"
            consumed_urn = "urn:li:dataset:(urn:li:dataPlatform:snowflake,mydb.schema.tableConsumed,PROD)"
            produced_urn = "urn:li:dataset:(urn:li:dataPlatform:snowflake,mydb.schema.tableProduced,PROD)"
            process_urn = "urn:li:dataProcessInstance:b6375e5f5faeb543cfb5d7d8a47661fb"

            # (aspectName, entityUrn) of each emitted call, in emission order.
            expected_calls = [
                ("dataFlowInfo", flow_urn),
                ("ownership", flow_urn),
                ("globalTags", flow_urn),
                ("dataJobInfo", job_urn),
                ("dataJobInputOutput", job_urn),
                ("status", consumed_urn),
                ("status", produced_urn),
                ("ownership", job_urn),
                ("globalTags", job_urn),
                ("dataProcessInstanceProperties", process_urn),
                ("dataProcessInstanceRelationships", process_urn),
                ("dataProcessInstanceInput", process_urn),
                ("dataProcessInstanceOutput", process_urn),
                ("status", consumed_urn),
                ("status", produced_urn),
                ("dataProcessInstanceRunEvent", process_urn),
                ("dataProcessInstanceRunEvent", process_urn),
            ]

            for position, (aspect_name, entity_urn) in enumerate(expected_calls):
                emitted = mock_emitter.method_calls[position].args[0]
                assert emitted.aspectName == aspect_name
                assert emitted.entityUrn == entity_urn

                if position == 4:
                    # The dataJobInputOutput aspect also carries the lineage edges.
                    assert emitted.aspect.inputDatajobs[0] == upstream_job_urn
                    assert emitted.aspect.inputDatasets[0] == consumed_urn
                    assert emitted.aspect.outputDatasets[0] == produced_urn