Example #1
File: dag.py Project: eijidepaz/marquez
def _get_location(task):
    try:
        # Prefer the task's own file path; fall back to the DAG file location.
        if hasattr(task, 'file_path') and task.file_path:
            return get_location(task.file_path)
        else:
            return get_location(task.dag.fileloc)
    except Exception:
        # Location is best-effort; swallow failures and report no location.
        return None
Example #2
def _get_location(task):
    try:
        # Prefer the task's own file path; fall back to the DAG file location.
        if hasattr(task, 'file_path') and task.file_path:
            return get_location(task.file_path)
        else:
            return get_location(task.dag.fileloc)
    except Exception:
        # Unlike Example #1, log the failure (with traceback) before
        # falling back to None.
        log.warning(f"Failed to get location for task '{task.task_id}'.",
                    exc_info=True)
        return None
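Examples #1 and #2 differ only in whether the failure is logged before falling back to None. Below is a minimal, self-contained sketch of the helper's behavior; the stubbed get_location and the stand-in task objects are illustrative assumptions, not the project's real implementations:

from types import SimpleNamespace

def get_location(file_path):
    # Illustrative stub: the real helper resolves file_path to a
    # source-code URL (see the get_location tests at the end of
    # this section).
    return f"https://example.com/blob/{file_path}"

def _get_location(task):
    try:
        if hasattr(task, 'file_path') and task.file_path:
            return get_location(task.file_path)
        else:
            return get_location(task.dag.fileloc)
    except Exception:
        return None

# A task's own file_path wins over the DAG file location.
task = SimpleNamespace(file_path='dags/etl.py', dag=None)
assert _get_location(task) == 'https://example.com/blob/dags/etl.py'

# Without file_path, the DAG's fileloc is used instead.
task = SimpleNamespace(file_path=None,
                       dag=SimpleNamespace(fileloc='dags/etl.py'))
assert _get_location(task) == 'https://example.com/blob/dags/etl.py'

# Any lookup failure degrades to None rather than raising.
assert _get_location(SimpleNamespace()) is None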
Example #3
def test_marquez_dag_with_extractor_returning_two_steps(
        job_id_mapping,
        mock_get_or_create_openlineage_client,
        clear_db_airflow_dags,
        session=None):

    # --- test setup
    dag_id = 'test_marquez_dag_with_extractor_returning_two_steps'
    dag = DAG(
        dag_id,
        schedule_interval='@daily',
        default_args=DAG_DEFAULT_ARGS,
        description=DAG_DESCRIPTION
    )

    dag_run_id = 'test_marquez_dag_with_extractor_returning_two_steps_run_id'
    run_id = f"{dag_run_id}.{TASK_ID_COMPLETED}"

    # Mock the marquez client method calls
    mock_marquez_client = mock.Mock()
    mock_get_or_create_openlineage_client.return_value = mock_marquez_client

    # Add task that will be marked as completed
    task_will_complete = TestFixtureDummyOperator(
        task_id=TASK_ID_COMPLETED,
        dag=dag
    )
    completed_task_location = get_location(task_will_complete.dag.fileloc)

    # Add the dummy extractor to the list for the task above
    _DAG_EXTRACTORS[task_will_complete.__class__] = \
        TestFixtureDummyExtractorWithMultipleSteps

    # --- pretend run the DAG

    # Create DAG run and mark as running
    dagrun = dag.create_dagrun(
        run_id=dag_run_id,
        execution_date=DEFAULT_DATE,
        state=State.RUNNING)

    # --- Assert that starting the job emits an OpenLineage START event

    start_time = '2016-01-01T00:00:00.000000Z'
    end_time = '2016-01-02T00:00:00.000000Z'

    mock_marquez_client.emit.assert_called_once_with(
        RunEvent(
            RunState.START,
            mock.ANY,
            Run(run_id, {"nominalTime": NominalTimeRunFacet(start_time, end_time)}),
            Job("default", f"{dag_id}.{TASK_ID_COMPLETED}", {
                "documentation": DocumentationJobFacet(DAG_DESCRIPTION),
                "sourceCodeLocation": SourceCodeLocationJobFacet("", completed_task_location)
            }),
            PRODUCER,
            [OpenLineageDataset(DAG_NAMESPACE, 'extract_input1', {
                "dataSource": DataSourceDatasetFacet(
                    name='dummy_source_name',
                    uri='http://dummy/source/url'
                )
            })],
            []
        )
    )

    mock_marquez_client.reset_mock()

    # --- Pretend complete the task
    job_id_mapping.pop.return_value = run_id

    task_will_complete.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

    dag.handle_callback(dagrun, success=True, session=session)

    # --- Assert that the OpenLineage COMPLETE event is emitted

    mock_marquez_client.emit.assert_called_once_with(
        RunEvent(
            RunState.COMPLETE,
            mock.ANY,
            Run(run_id),
            Job("default", f"{dag_id}.{TASK_ID_COMPLETED}"),
            PRODUCER,
            [OpenLineageDataset(DAG_NAMESPACE, 'extract_input1', {
                "dataSource": DataSourceDatasetFacet(
                    name='dummy_source_name',
                    uri='http://dummy/source/url'
                )
            })],
            []
        )
    )
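For context, test signatures like the one above receive their first two arguments from patch decorators (applied bottom-up) and the third from a pytest fixture. A plausible wiring follows; the patch targets and the fixture body are assumptions about the project layout, not verbatim from the source:

import pytest
from unittest import mock

@pytest.fixture
def clear_db_airflow_dags():
    # Hypothetical fixture: reset Airflow's DAG/TaskInstance tables
    # between tests.
    yield

@mock.patch('marquez_airflow.dag.get_or_create_openlineage_client')  # assumed target
@mock.patch('marquez_airflow.dag.JobIdMapping')                      # assumed target
def test_example(job_id_mapping,
                 mock_get_or_create_openlineage_client,
                 clear_db_airflow_dags,
                 session=None):
    # The bottom-most decorator supplies the first mock argument,
    # matching the parameter order used in the tests above.
    ...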
Example #4
def test_marquez_dag(job_id_mapping, mock_get_or_create_openlineage_client,
                     clear_db_airflow_dags, session=None):

    dag = DAG(
        DAG_ID,
        schedule_interval='@daily',
        default_args=DAG_DEFAULT_ARGS,
        description=DAG_DESCRIPTION
    )
    # (1) Mock the marquez client method calls
    mock_marquez_client = mock.Mock()
    mock_get_or_create_openlineage_client.return_value = mock_marquez_client
    run_id_completed = f"{DAG_RUN_ID}.{TASK_ID_COMPLETED}"
    run_id_failed = f"{DAG_RUN_ID}.{TASK_ID_FAILED}"

    # (2) Add task that will be marked as completed
    task_will_complete = DummyOperator(
        task_id=TASK_ID_COMPLETED,
        dag=dag
    )
    completed_task_location = get_location(task_will_complete.dag.fileloc)

    # (3) Add task that will be marked as failed
    task_will_fail = DummyOperator(
        task_id=TASK_ID_FAILED,
        dag=dag
    )
    failed_task_location = get_location(task_will_fail.dag.fileloc)

    # (4) Create DAG run and mark as running
    dagrun = dag.create_dagrun(
        run_id=DAG_RUN_ID,
        execution_date=DEFAULT_DATE,
        state=State.RUNNING)

    # Assert emit calls
    start_time = '2016-01-01T00:00:00.000000Z'
    end_time = '2016-01-02T00:00:00.000000Z'

    emit_calls = [
        mock.call(RunEvent(
            eventType=RunState.START,
            eventTime=mock.ANY,
            run=Run(run_id_completed, {"nominalTime": NominalTimeRunFacet(start_time, end_time)}),
            job=Job("default", f"{DAG_ID}.{TASK_ID_COMPLETED}", {
                "documentation": DocumentationJobFacet(DAG_DESCRIPTION),
                "sourceCodeLocation": SourceCodeLocationJobFacet("", completed_task_location)
            }),
            producer=PRODUCER,
            inputs=[],
            outputs=[]
        )),
        mock.call(RunEvent(
            eventType=RunState.START,
            eventTime=mock.ANY,
            run=Run(run_id_failed, {"nominalTime": NominalTimeRunFacet(start_time, end_time)}),
            job=Job("default", f"{DAG_ID}.{TASK_ID_FAILED}", {
                "documentation": DocumentationJobFacet(DAG_DESCRIPTION),
                "sourceCodeLocation": SourceCodeLocationJobFacet("", failed_task_location)
            }),
            producer=PRODUCER,
            inputs=[],
            outputs=[]
        ))
    ]
    log.info(
        f"{[name for name, args, kwargs in mock_marquez_client.mock_calls]}")
    mock_marquez_client.emit.assert_has_calls(emit_calls)

    # (5) Start task that will be marked as completed
    task_will_complete.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

    # (6) Start task that will be marked as failed
    ti1 = TaskInstance(task=task_will_fail, execution_date=DEFAULT_DATE)
    ti1.state = State.FAILED
    session.add(ti1)
    session.commit()

    job_id_mapping.pop.side_effect = [run_id_completed, run_id_failed]

    dag.handle_callback(dagrun, success=False, session=session)

    emit_calls += [
        mock.call(RunEvent(
            eventType=RunState.COMPLETE,
            eventTime=mock.ANY,
            run=Run(run_id_completed),
            job=Job("default", f"{DAG_ID}.{TASK_ID_COMPLETED}"),
            producer=PRODUCER,
            inputs=[],
            outputs=[]
        )),
        mock.call(RunEvent(
            eventType=RunState.FAIL,
            eventTime=mock.ANY,
            run=Run(run_id_failed),
            job=Job("default", f"{DAG_ID}.{TASK_ID_FAILED}"),
            producer=PRODUCER,
            inputs=[],
            outputs=[]
        ))
    ]
    mock_marquez_client.emit.assert_has_calls(emit_calls)
Example #5
def test_get_location_no_file_path():
    assert get_location(None) is None
    assert get_location("") is None
Example #6
def test_marquez_dag_with_extract_on_complete(
        job_id_mapping,
        mock_get_or_create_openlineage_client,
        clear_db_airflow_dags,
        session=None):

    # --- test setup
    dag_id = 'test_marquez_dag_with_extractor_on_complete'
    dag = DAG(
        dag_id,
        schedule_interval='@daily',
        default_args=DAG_DEFAULT_ARGS,
        description=DAG_DESCRIPTION
    )

    dag_run_id = 'test_marquez_dag_with_extractor_run_id'
    run_id = f"{dag_run_id}.{TASK_ID_COMPLETED}"
    # Mock the marquez client method calls
    mock_marquez_client = mock.Mock()
    mock_get_or_create_openlineage_client.return_value = mock_marquez_client

    # Add task that will be marked as completed
    task_will_complete = TestFixtureDummyOperator(
        task_id=TASK_ID_COMPLETED,
        dag=dag
    )
    completed_task_location = get_location(task_will_complete.dag.fileloc)

    # Add the dummy extractor to the list for the task above
    _DAG_EXTRACTORS[task_will_complete.__class__] = \
        TestFixtureDummyExtractorOnComplete

    # Create DAG run and mark as running
    dagrun = dag.create_dagrun(
        run_id=dag_run_id,
        execution_date=DEFAULT_DATE,
        state=State.RUNNING)

    start_time = '2016-01-01T00:00:00.000000Z'
    end_time = '2016-01-02T00:00:00.000000Z'

    mock_marquez_client.emit.assert_has_calls([
        mock.call(RunEvent(
            eventType=RunState.START,
            eventTime=mock.ANY,
            run=Run(run_id, {
                "nominalTime": NominalTimeRunFacet(start_time, end_time)
            }),
            job=Job("default",  f"{dag_id}.{TASK_ID_COMPLETED}", {
                "documentation": DocumentationJobFacet(DAG_DESCRIPTION),
                "sourceCodeLocation": SourceCodeLocationJobFacet("", completed_task_location)
            }),
            producer=PRODUCER,
            inputs=[],
            outputs=[]
        ))
    ])

    mock_marquez_client.reset_mock()

    # --- Pretend complete the task
    job_id_mapping.pop.return_value = run_id

    task_will_complete.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

    dag.handle_callback(dagrun, success=True, session=session)

    mock_marquez_client.emit.assert_has_calls([
        mock.call(RunEvent(
            eventType=RunState.COMPLETE,
            eventTime=mock.ANY,
            run=Run(run_id),
            job=Job("default", f"{dag_id}.{TASK_ID_COMPLETED}"),
            producer=PRODUCER,
            inputs=[OpenLineageDataset(
                namespace='default',
                name='schema.extract_on_complete_input1',
                facets={
                    'dataSource': DataSourceDatasetFacet(
                        name='dummy_source_name',
                        uri='http://dummy/source/url'
                    ),
                    'schema': SchemaDatasetFacet(
                        fields=[
                            SchemaField(name='field1', type='text', description=''),
                            SchemaField(name='field2', type='text', description='')
                        ]
                    )
                })
            ],
            outputs=[OpenLineageDataset(
                namespace='default',
                name='extract_on_complete_output1',
                facets={
                    'dataSource': DataSourceDatasetFacet(
                        name='dummy_source_name',
                        uri='http://dummy/source/url'
                    )
                })
            ]
        ))
    ])
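Below is a sketch of what a fixture like TestFixtureDummyExtractorOnComplete could look like, inferred from the datasets asserted above; the metadata-model stand-ins and the extractor interface (extract() at START, extract_on_complete() at completion) are assumptions based on Example #7's usage of step.inputs/outputs/context:

from dataclasses import dataclass, field
from typing import List

# Minimal stand-ins for the project's metadata model (assumed shapes).
@dataclass
class Source:
    name: str
    type: str
    connection_url: str

@dataclass
class Field:
    name: str
    type: str

@dataclass
class Dataset:
    name: str
    source: Source
    fields: List[Field] = field(default_factory=list)

@dataclass
class StepMetadata:
    name: str
    inputs: List[Dataset] = field(default_factory=list)
    outputs: List[Dataset] = field(default_factory=list)
    context: dict = field(default_factory=dict)

class TestFixtureDummyExtractorOnComplete:
    def __init__(self, operator):
        self.operator = operator

    def extract(self):
        # Nothing is known before the task runs.
        return []

    def extract_on_complete(self, task_instance):
        # Report the datasets the completed task actually touched.
        source = Source('dummy_source_name', 'DummySource',
                        'http://dummy/source/url')
        return [StepMetadata(
            name=f"{self.operator.dag_id}.{self.operator.task_id}",
            inputs=[Dataset('schema.extract_on_complete_input1', source,
                            [Field('field1', 'text'),
                             Field('field2', 'text')])],
            outputs=[Dataset('extract_on_complete_output1', source)],
            context={'extract_on_complete': 'extract_on_complete'})]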
Example #7
    def report_task(self, dag_run_id, execution_date, run_args, task,
                    extractor):

        report_job_start_ms = self._now_ms()
        marquez_client = self.get_marquez_client()
        if execution_date:
            start_time = self._to_iso_8601(execution_date)
            end_time = self.compute_endtime(execution_date)
        else:
            start_time = None
            end_time = None

        if end_time:
            end_time = self._to_iso_8601(end_time)

        task_location = None
        try:
            if hasattr(task, 'file_path') and task.file_path:
                task_location = get_location(task.file_path)
            else:
                task_location = get_location(task.dag.fileloc)
        except Exception:
            log.warning('Unable to fetch the task location.')

        steps_metadata = []
        if extractor:
            try:
                log.info(f'Using extractor {extractor.__name__}',
                         task_type=task.__class__.__name__,
                         airflow_dag_id=self.dag_id,
                         task_id=task.task_id,
                         airflow_run_id=dag_run_id,
                         marquez_namespace=self.marquez_namespace)
                steps_metadata = extractor(task).extract()
            except Exception as e:
                log.error(f'Failed to extract metadata: {e}',
                          airflow_dag_id=self.dag_id,
                          task_id=task.task_id,
                          airflow_run_id=dag_run_id,
                          marquez_namespace=self.marquez_namespace)
        else:
            log.warning('Unable to find an extractor.',
                        task_type=task.__class__.__name__,
                        airflow_dag_id=self.dag_id,
                        task_id=task.task_id,
                        airflow_run_id=dag_run_id,
                        marquez_namespace=self.marquez_namespace)

        task_name = f'{self.dag_id}.{task.task_id}'

        # If no extractor found or failed to extract metadata,
        # report the task metadata
        if not steps_metadata:
            steps_metadata = [StepMetadata(task_name)]

        # store all the JobRuns associated with a task
        marquez_jobrun_ids = []

        for step in steps_metadata:
            input_datasets = []
            output_datasets = []

            try:
                input_datasets = self.register_datasets(step.inputs)
            except Exception as e:
                log.error(f'Failed to register inputs: {e}',
                          inputs=str(step.inputs),
                          airflow_dag_id=self.dag_id,
                          task_id=task.task_id,
                          step=step.name,
                          airflow_run_id=dag_run_id,
                          marquez_namespace=self.marquez_namespace)
            try:
                output_datasets = self.register_datasets(step.outputs)
            except Exception as e:
                log.error(f'Failed to register outputs: {e}',
                          outputs=str(step.outputs),
                          airflow_dag_id=self.dag_id,
                          task_id=task.task_id,
                          step=step.name,
                          airflow_run_id=dag_run_id,
                          marquez_namespace=self.marquez_namespace)

            marquez_client.create_job(
                job_name=step.name,
                job_type='BATCH',  # job type
                location=(step.location or task_location),
                input_dataset=input_datasets,
                output_dataset=output_datasets,
                context=step.context,
                description=self.description,
                namespace_name=self.marquez_namespace)
            log.info(f'Successfully recorded job: {step.name}',
                     airflow_dag_id=self.dag_id,
                     marquez_namespace=self.marquez_namespace)

            marquez_jobrun_id = marquez_client.create_job_run(
                step.name,
                run_args=run_args,
                nominal_start_time=start_time,
                nominal_end_time=end_time).get('runId')

            if marquez_jobrun_id:
                marquez_jobrun_ids.append(marquez_jobrun_id)
                marquez_client.mark_job_run_as_started(marquez_jobrun_id)
                log.info(f'Successfully recorded job run: {step.name}',
                         airflow_dag_id=self.dag_id,
                         airflow_dag_execution_time=start_time,
                         marquez_run_id=marquez_jobrun_id,
                         marquez_namespace=self.marquez_namespace,
                         duration_ms=(self._now_ms() - report_job_start_ms))
            else:
                # Only log success above when a run id actually came back.
                log.error(f'Failed to get run id: {step.name}',
                          airflow_dag_id=self.dag_id,
                          airflow_run_id=dag_run_id,
                          marquez_namespace=self.marquez_namespace)

        # Store the mapping for all the steps associated with a task
        try:
            self._job_id_mapping.set(
                JobIdMapping.make_key(task_name, dag_run_id),
                json.dumps(marquez_jobrun_ids))

        except Exception as e:
            log.error(f'Failed to set id mapping: {e}',
                      airflow_dag_id=self.dag_id,
                      task_id=task.task_id,
                      airflow_run_id=dag_run_id,
                      marquez_run_id=marquez_jobrun_ids,
                      marquez_namespace=self.marquez_namespace)
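report_task leans on a few small time helpers. Plausible implementations follow, assuming the enclosing class subclasses Airflow's DAG (the class name here is hypothetical; following_schedule is a real Airflow 1.x DAG method):

import time

from airflow import DAG

class LineageDAGSketch(DAG):

    def _now_ms(self):
        # Wall-clock milliseconds, used for the duration_ms log field.
        return int(round(time.time() * 1000))

    def compute_endtime(self, execution_date):
        # Nominal end = the next scheduled run after execution_date; for
        # an '@daily' DAG that is one day later, which is why the tests
        # pair 2016-01-01T00:00:00.000000Z with 2016-01-02T00:00:00.000000Z.
        return self.following_schedule(execution_date)

    @staticmethod
    def _to_iso_8601(dt):
        # Zulu-suffixed ISO-8601 with microseconds, matching the test strings.
        return dt.strftime('%Y-%m-%dT%H:%M:%S.%fZ')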
Example #8
def test_marquez_dag(mock_get_or_create_marquez_client,
                     mock_uuid,
                     clear_db_airflow_dags,
                     session=None):

    dag = DAG(DAG_ID,
              schedule_interval='@daily',
              default_args=DAG_DEFAULT_ARGS,
              description=DAG_DESCRIPTION)
    # (1) Mock the marquez client method calls
    mock_marquez_client = mock.Mock()
    mock_get_or_create_marquez_client.return_value = mock_marquez_client
    run_id_completed = "my-test_marquez_dag-uuid-completed"
    run_id_failed = "my-test_marquez_dag-uuid-failed"
    mock_uuid.side_effect = [run_id_completed, run_id_failed]

    # (2) Add task that will be marked as completed
    task_will_complete = DummyOperator(task_id=TASK_ID_COMPLETED, dag=dag)
    completed_task_location = get_location(task_will_complete.dag.fileloc)

    # (3) Add task that will be marked as failed
    task_will_fail = DummyOperator(task_id=TASK_ID_FAILED, dag=dag)
    failed_task_location = get_location(task_will_fail.dag.fileloc)

    # (4) Create DAG run and mark as running
    dagrun = dag.create_dagrun(run_id=DAG_RUN_ID,
                               execution_date=DEFAULT_DATE,
                               state=State.RUNNING)

    # Assert namespace meta call
    mock_marquez_client.create_namespace.assert_called_once_with(
        DAG_NAMESPACE, DAG_OWNER)

    # Assert source and dataset meta calls
    mock_marquez_client.create_source.assert_not_called()
    mock_marquez_client.create_dataset.assert_not_called()

    # Assert job meta calls
    create_job_calls = [
        mock.call(job_name=f"{DAG_ID}.{TASK_ID_COMPLETED}",
                  job_type=JobType.BATCH,
                  location=completed_task_location,
                  input_dataset=None,
                  output_dataset=None,
                  context=mock.ANY,
                  description=DAG_DESCRIPTION,
                  namespace_name=DAG_NAMESPACE,
                  run_id=None),
        mock.call(job_name=f"{DAG_ID}.{TASK_ID_FAILED}",
                  job_type=JobType.BATCH,
                  location=failed_task_location,
                  input_dataset=None,
                  output_dataset=None,
                  context=mock.ANY,
                  description=DAG_DESCRIPTION,
                  namespace_name=DAG_NAMESPACE,
                  run_id=None)
    ]
    log.info(
        f"{[name for name, args, kwargs in mock_marquez_client.mock_calls]}")
    mock_marquez_client.create_job.assert_has_calls(create_job_calls)

    # Assert job run meta calls
    create_job_run_calls = [
        mock.call(job_name=f"{DAG_ID}.{TASK_ID_COMPLETED}",
                  run_id=mock.ANY,
                  run_args=DAG_RUN_ARGS,
                  nominal_start_time=mock.ANY,
                  nominal_end_time=mock.ANY,
                  namespace_name=DAG_NAMESPACE),
        mock.call(job_name=f"{DAG_ID}.{TASK_ID_FAILED}",
                  run_id=mock.ANY,
                  run_args=DAG_RUN_ARGS,
                  nominal_start_time=mock.ANY,
                  nominal_end_time=mock.ANY,
                  namespace_name=DAG_NAMESPACE)
    ]
    mock_marquez_client.create_job_run.assert_has_calls(create_job_run_calls)

    # (5) Start task that will be marked as completed
    task_will_complete.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

    # (6) Start task that will be marked as failed
    ti1 = TaskInstance(task=task_will_fail, execution_date=DEFAULT_DATE)
    ti1.state = State.FAILED
    session.add(ti1)
    session.commit()

    dag.handle_callback(dagrun, success=True, session=session)

    # Assert start run meta calls
    start_job_run_calls = [
        mock.call(run_id_completed, mock.ANY),
        mock.call(run_id_failed, mock.ANY)
    ]
    mock_marquez_client.mark_job_run_as_started.assert_has_calls(
        start_job_run_calls)

    mock_marquez_client.mark_job_run_as_completed.assert_called_once_with(
        run_id=run_id_completed)

    # When a task run completes, the task outputs are also updated in
    # order to link a job version (=task version) to a dataset version.
    # Since a DummyOperator has no outputs, assert that the create
    # dataset call is not invoked.
    mock_marquez_client.create_dataset.assert_not_called()

    dag.handle_callback(dagrun, success=False, session=session)
    mock_marquez_client.mark_job_run_as_failed.assert_called_once_with(
        run_id=run_id_failed)

    # Assert that no attempt is made to version the outputs of a task
    # when the task fails
    mock_marquez_client.create_dataset.assert_not_called()
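The job_id_mapping mocked in the tests above stands in for the run-id bookkeeping used by report_task in Example #7: set() is called with JobIdMapping.make_key(task_name, dag_run_id) at task start, and the completion path pops the same key to know which runs to mark completed or failed. An in-memory sketch of that protocol follows; the key format is an assumption, and the real class must persist state across processes:

class JobIdMapping:
    # Illustrative in-memory store; the tests mock this out entirely.
    _store = {}

    @staticmethod
    def make_key(job_name, run_id):
        # Assumed key scheme: one entry per (task, DAG run) pair.
        return f"marquez_id_mapping-{job_name}-{run_id}"

    def set(self, key, value):
        self._store[key] = value

    def pop(self, key):
        # Return and forget the run id(s) recorded at task start.
        return self._store.pop(key, None)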
Example #9
def test_marquez_dag_with_extract_on_complete(
        mock_get_or_create_marquez_client,
        mock_uuid,
        clear_db_airflow_dags,
        session=None):

    # --- test setup
    dag_id = 'test_marquez_dag_with_extractor'
    dag = DAG(dag_id,
              schedule_interval='@daily',
              default_args=DAG_DEFAULT_ARGS,
              description=DAG_DESCRIPTION)

    run_id = "my-test-uuid"
    mock_uuid.side_effect = [run_id]
    # Mock the marquez client method calls
    mock_marquez_client = mock.Mock()
    mock_get_or_create_marquez_client.return_value = mock_marquez_client

    # Add task that will be marked as completed
    task_will_complete = TestFixtureDummyOperator(task_id=TASK_ID_COMPLETED,
                                                  dag=dag)
    completed_task_location = get_location(task_will_complete.dag.fileloc)

    # Add the dummy extractor to the list for the task above
    dag._extractors[task_will_complete.__class__] = \
        TestFixtureDummyExtractorOnComplete

    # Create DAG run and mark as running
    dagrun = dag.create_dagrun(run_id='test_marquez_dag_with_extractor_run_id',
                               execution_date=DEFAULT_DATE,
                               state=State.RUNNING)

    # Namespace created
    mock_marquez_client.create_namespace.assert_called_once_with(
        DAG_NAMESPACE, DAG_OWNER)

    log.info("Marquez client calls when starting:")
    for call in mock_marquez_client.mock_calls:
        log.info(call)

    assert [name for name, args, kwargs in mock_marquez_client.mock_calls
            ] == ['create_namespace']
    mock_marquez_client.reset_mock()

    # --- Pretend complete the task
    task_will_complete.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)

    dag.handle_callback(dagrun, success=True, session=session)

    # Datasets are updated
    mock_marquez_client.create_source.assert_called_with(
        'dummy_source_name', 'DummySource', 'http://dummy/source/url')
    # create_dataset is called twice per dataset: once to reenact
    # _begin_run_flow and again at _end_run_flow, where the run id is
    # attached to the output dataset
    mock_marquez_client.create_dataset.assert_has_calls([
        mock.call(dataset_name='schema.extract_on_complete_input1',
                  dataset_type=DatasetType.DB_TABLE,
                  physical_name='schema.extract_on_complete_input1',
                  source_name='dummy_source_name',
                  namespace_name=DAG_NAMESPACE,
                  fields=mock.ANY,
                  run_id=None),
        mock.call(dataset_name='extract_on_complete_output1',
                  dataset_type=DatasetType.DB_TABLE,
                  physical_name='extract_on_complete_output1',
                  source_name='dummy_source_name',
                  namespace_name=DAG_NAMESPACE,
                  fields=[],
                  run_id=None),
        mock.call(dataset_name='schema.extract_on_complete_input1',
                  dataset_type=DatasetType.DB_TABLE,
                  physical_name='schema.extract_on_complete_input1',
                  source_name='dummy_source_name',
                  namespace_name=DAG_NAMESPACE,
                  fields=mock.ANY,
                  run_id=None),
        mock.call(dataset_name='extract_on_complete_output1',
                  dataset_type=DatasetType.DB_TABLE,
                  physical_name='extract_on_complete_output1',
                  source_name='dummy_source_name',
                  namespace_name=DAG_NAMESPACE,
                  fields=[],
                  run_id='my-test-uuid')
    ])

    # job is updated
    mock_marquez_client.create_job.assert_has_calls([
        mock.call(job_name=f"{dag_id}.{TASK_ID_COMPLETED}",
                  job_type=JobType.BATCH,
                  location=completed_task_location,
                  input_dataset=[{
                      'namespace': 'default',
                      'name': 'schema.extract_on_complete_input1'
                  }],
                  output_dataset=[{
                      'namespace': 'default',
                      'name': 'extract_on_complete_output1'
                  }],
                  context=mock.ANY,
                  description=DAG_DESCRIPTION,
                  namespace_name=DAG_NAMESPACE,
                  run_id=None),
        mock.call(job_name=f"{dag_id}.{TASK_ID_COMPLETED}",
                  job_type=JobType.BATCH,
                  location=completed_task_location,
                  input_dataset=[{
                      'namespace': 'default',
                      'name': 'schema.extract_on_complete_input1'
                  }],
                  output_dataset=[{
                      'namespace': 'default',
                      'name': 'extract_on_complete_output1'
                  }],
                  context=mock.ANY,
                  description=DAG_DESCRIPTION,
                  namespace_name=DAG_NAMESPACE,
                  run_id='my-test-uuid')
    ])
    assert mock_marquez_client.create_job.mock_calls[0].\
        kwargs['context'].get('extract_on_complete') == 'extract_on_complete'

    # run is created
    mock_marquez_client.create_job_run.assert_called_once_with(
        job_name=f"{dag_id}.{TASK_ID_COMPLETED}",
        run_id=run_id,
        run_args=DAG_RUN_ARGS,
        nominal_start_time=mock.ANY,
        nominal_end_time=mock.ANY,
        namespace_name=DAG_NAMESPACE)

    # run is started
    mock_marquez_client.mark_job_run_as_started.assert_called_once_with(
        run_id, mock.ANY)

    # --- Assert that the right marquez calls are done

    # job is updated before completion
    mock_marquez_client.create_job.assert_has_calls([
        mock.call(namespace_name=DAG_NAMESPACE,
                  job_name=f"{dag_id}.{TASK_ID_COMPLETED}",
                  job_type=JobType.BATCH,
                  location=completed_task_location,
                  input_dataset=[{
                      'namespace': 'default',
                      'name': 'schema.extract_on_complete_input1'
                  }],
                  output_dataset=[{
                      'namespace': 'default',
                      'name': 'extract_on_complete_output1'
                  }],
                  context=mock.ANY,
                  description=DAG_DESCRIPTION,
                  run_id=run_id)
    ])

    assert mock_marquez_client.create_job.mock_calls[0].\
        kwargs['context'].get('extract_on_complete') == 'extract_on_complete'

    mock_marquez_client.mark_job_run_as_completed.assert_called_once_with(
        run_id=run_id)

    # When a task run completes, the task outputs are also updated in order
    # to link a job version (=task version) to a dataset version.
    mock_marquez_client.create_dataset.assert_has_calls([
        mock.call(dataset_name='schema.extract_on_complete_input1',
                  dataset_type=DatasetType.DB_TABLE,
                  physical_name='schema.extract_on_complete_input1',
                  source_name='dummy_source_name',
                  namespace_name=DAG_NAMESPACE,
                  fields=mock.ANY,
                  run_id=None),
        mock.call(dataset_name='extract_on_complete_output1',
                  dataset_type=DatasetType.DB_TABLE,
                  physical_name='extract_on_complete_output1',
                  source_name='dummy_source_name',
                  namespace_name=DAG_NAMESPACE,
                  fields=[],
                  run_id=run_id)
    ])

    log.info("Marquez client calls when completing:")
    for call in mock_marquez_client.mock_calls:
        log.info(call)
    assert [name for name, args, kwargs in mock_marquez_client.mock_calls] == [
        'create_namespace', 'create_source', 'create_dataset', 'create_source',
        'create_dataset', 'create_job', 'create_job_run', 'create_source',
        'create_dataset', 'create_source', 'create_dataset', 'create_job',
        'mark_job_run_as_started', 'mark_job_run_as_completed'
    ]
Example #10
def test_bad_file_path(git_mock):
    with pytest.raises(FileNotFoundError):
        # invalid file
        get_location("dags/missing-dag.py")

def test_dag_location(git_mock):
    assert ('https://github.com/MarquezProject/marquez-airflow/blob/'
            'abcd1234/tests/test_dags/test_dag.py' == get_location(
                "tests/test_dags/test_dag.py"))