def _report_task_instance(self, task_instance, dagrun, session):
    task = self.get_task(task_instance.task_id)

    # Note: task_run_id could be missing if it was removed from airflow
    # or the job could not be registered.
    task_run_id = JobIdMapping.pop(
        self._marquez_job_name_from_task_instance(task_instance),
        dagrun.run_id,
        session)

    step = self._extract_metadata(dagrun, task, task_instance)

    job_name = self._marquez_job_name(self.dag_id, task.task_id)
    run_id = self._marquez_run_id(dagrun.run_id, task.task_id)

    if not task_run_id:
        task_run_id = _MARQUEZ.start_task(
            run_id,
            job_name,
            self.description,
            DagUtils.to_iso_8601(task_instance.start_date),
            dagrun.run_id,
            self._get_location(task),
            DagUtils.to_iso_8601(task_instance.start_date),
            DagUtils.to_iso_8601(task_instance.end_date),
            step,
            {**step.run_facets, **get_custom_facets(task, False)}
        )
        if not task_run_id:
            self.log.warning('Could not emit lineage')

    self.log.debug(f'Setting task state: {task_instance.state}'
                   f' for {task_instance.task_id}')
    if task_instance.state in {State.SUCCESS, State.SKIPPED}:
        _MARQUEZ.complete_task(
            task_run_id,
            job_name,
            DagUtils.to_iso_8601(task_instance.end_date),
            step
        )
    else:
        _MARQUEZ.fail_task(
            task_run_id,
            job_name,
            DagUtils.to_iso_8601(task_instance.end_date),
            step
        )
def _report_task_instance(self, ti, dagrun, run_args, session):
    task = self.get_task(ti.task_id)
    run_ids = JobIdMapping.pop(self._marquez_job_name_from_ti(ti),
                               dagrun.run_id, session)
    steps = self._extract_metadata(dagrun, task, ti)

    # Note: run_ids could be missing if it was removed from airflow
    # or the job could not be registered.
    if not run_ids:
        [
            _MARQUEZ.create_job(step, self._get_location(task),
                                self.description)
            for step in steps
        ]
        run_ids = [
            _MARQUEZ.create_run(self.new_run_id(),
                                step,
                                run_args,
                                DagUtils.to_iso_8601(ti.start_date),
                                DagUtils.to_iso_8601(ti.end_date))
            for step in steps
        ]
        if not run_ids:
            self.log.warn('Could not emit lineage')

    for step in steps:
        for run_id in run_ids:
            _MARQUEZ.create_job(step, self._get_location(task),
                                self.description, ti.state, run_id)
            _MARQUEZ.start_run(run_id,
                               DagUtils.to_iso_8601(ti.start_date))

            self.log.debug(f'Setting task state: {ti.state}'
                           f' for {ti.task_id}')
            if ti.state in {State.SUCCESS, State.SKIPPED}:
                _MARQUEZ.complete_run(run_id,
                                      DagUtils.to_iso_8601(ti.end_date))
            else:
                _MARQUEZ.fail_run(run_id,
                                  DagUtils.to_iso_8601(ti.end_date))
class DAG(airflow.models.DAG, LoggingMixin):
    _job_id_mapping = None
    _marquez = None

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._job_id_mapping = JobIdMapping()
        self._marquez = Marquez()
        # TODO: Manually define operator->extractor mappings for now,
        # but we'll want to encapsulate this logic in an 'Extractors' class
        # with more convenient methods (ex: 'Extractors.extractor_for_task()')
        self._extractors = {
            PostgresOperator: PostgresExtractor,
            BigQueryOperator: BigQueryExtractor
            # Append new extractors here
        }
        self.log.debug(
            f"DAG successfully created with extractors: {self._extractors}")

    def create_dagrun(self, *args, **kwargs):
        # run Airflow's create_dagrun() first
        dagrun = super(DAG, self).create_dagrun(*args, **kwargs)

        create_dag_start_ms = self._now_ms()
        try:
            self._marquez.create_namespace()
            self._register_dagrun(dagrun,
                                  DagUtils.get_execution_date(**kwargs),
                                  DagUtils.get_run_args(**kwargs))
        except Exception as e:
            self.log.error(
                f'Failed to record metadata: {e} '
                f'{self._timed_log_message(create_dag_start_ms)}',
                exc_info=True)

        return dagrun

    def _register_dagrun(self, dagrun, execution_date, run_args):
        self.log.debug(f"self.task_dict: {self.task_dict}")
        # Register each task in the DAG
        for task_id, task in self.task_dict.items():
            t = self._now_ms()
            try:
                steps = self._extract_metadata(dagrun, task)

                [
                    self._marquez.create_job(step, self._get_location(task),
                                             self.description)
                    for step in steps
                ]
                marquez_jobrun_ids = [
                    self._marquez.create_run(
                        self.new_run_id(),
                        step,
                        run_args,
                        DagUtils.get_start_time(execution_date),
                        DagUtils.get_end_time(
                            execution_date,
                            self.following_schedule(execution_date)))
                    for step in steps
                ]

                self._job_id_mapping.set(
                    self._marquez_job_name(self.dag_id, task.task_id),
                    dagrun.run_id,
                    marquez_jobrun_ids)
            except Exception as e:
                self.log.error(
                    f'Failed to record task {task_id}: {e} '
                    f'{self._timed_log_message(t)}',
                    exc_info=True)

    def handle_callback(self, *args, **kwargs):
        self.log.debug(f"handle_callback({args}, {kwargs})")

        try:
            dagrun = args[0]
            self.log.debug(f"handle_callback() dagrun : {dagrun}")
            self._marquez.create_namespace()
            self._report_task_instances(dagrun,
                                        DagUtils.get_run_args(**kwargs),
                                        kwargs.get('session'))
        except Exception as e:
            self.log.error(
                f'Failed to record dagrun callback: {e} '
                f'dag_id={self.dag_id}',
                exc_info=True)

        return super().handle_callback(*args)

    def _report_task_instances(self, dagrun, run_args, session):
        task_instances = dagrun.get_task_instances()
        for ti in task_instances:
            try:
                self._report_task_instance(ti, dagrun, run_args, session)
            except Exception as e:
                self.log.error(
                    f'Failed to record task instance: {e} '
                    f'dag_id={self.dag_id}',
                    exc_info=True)

    def _report_task_instance(self, ti, dagrun, run_args, session):
        task = self.get_task(ti.task_id)
        run_ids = self._job_id_mapping.pop(
            self._marquez_job_name_from_ti(ti), dagrun.run_id, session)
        steps = self._extract_metadata(dagrun, task, ti)

        # Note: run_ids could be missing if it was removed from airflow
        # or the job could not be registered.
        if not run_ids:
            [
                self._marquez.create_job(step, self._get_location(task),
                                         self.description)
                for step in steps
            ]
            run_ids = [
                self._marquez.create_run(self.new_run_id(),
                                         step,
                                         run_args,
                                         DagUtils.to_iso_8601(ti.start_date),
                                         DagUtils.to_iso_8601(ti.end_date))
                for step in steps
            ]
            if not run_ids:
                self.log.warn('Could not emit lineage')

        for step in steps:
            for run_id in run_ids:
                self._marquez.create_job(step, self._get_location(task),
                                         self.description, ti.state, run_id)
                self._marquez.start_run(run_id,
                                        DagUtils.to_iso_8601(ti.start_date))

                self.log.debug(f'Setting task state: {ti.state}'
                               f' for {ti.task_id}')
                if ti.state in {State.SUCCESS, State.SKIPPED}:
                    self._marquez.complete_run(
                        run_id, DagUtils.to_iso_8601(ti.end_date))
                else:
                    self._marquez.fail_run(
                        run_id, DagUtils.to_iso_8601(ti.end_date))

    def _extract_metadata(self, dagrun, task, ti=None):
        extractor = self._get_extractor(task)
        task_info = f'task_type={task.__class__.__name__} ' \
            f'airflow_dag_id={self.dag_id} ' \
            f'task_id={task.task_id} ' \
            f'airflow_run_id={dagrun.run_id} '
        if extractor:
            try:
                self.log.debug(
                    f'Using extractor {extractor.__name__} {task_info}')
                steps = self._extract(extractor, task, ti)
                return add_airflow_info_to(task, steps)
            except Exception as e:
                self.log.error(f'Failed to extract metadata {e} {task_info}',
                               exc_info=True)
        else:
            self.log.warning(f'Unable to find an extractor. {task_info}')

        return add_airflow_info_to(task, [
            StepMetadata(
                name=self._marquez_job_name(self.dag_id, task.task_id))
        ])

    def _extract(self, extractor, task, ti):
        if ti:
            steps = extractor(task).extract_on_complete(ti)
            if steps:
                return steps

        return extractor(task).extract()

    def _get_extractor(self, task):
        extractor = self._extractors.get(task.__class__)
        log.debug(f'extractor for {task.__class__} is {extractor}')
        return extractor

    def _timed_log_message(self, start_time):
        return f'airflow_dag_id={self.dag_id} ' \
            f'duration_ms={(self._now_ms() - start_time)}'

    def new_run_id(self) -> str:
        return str(uuid4())

    @staticmethod
    def _now_ms():
        return int(round(time.time() * 1000))

    @staticmethod
    def _get_location(task):
        try:
            if hasattr(task, 'file_path') and task.file_path:
                return get_location(task.file_path)
            else:
                return get_location(task.dag.fileloc)
        except Exception:
            log.warning(f"Failed to get location for task '{task.task_id}'.",
                        exc_info=True)
            return None

    @staticmethod
    def _marquez_job_name_from_ti(ti):
        return DAG._marquez_job_name(ti.dag_id, ti.task_id)

    @staticmethod
    def _marquez_job_name(dag_id, task_id):
        return f'{dag_id}.{task_id}'
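# Usage sketch (illustrative assumption, not part of the original module):
# a DAG file opts into Marquez lineage collection by importing DAG from
# marquez_airflow instead of airflow, with no other code changes. The
# dag_id, schedule, and task below are hypothetical placeholders.
from datetime import datetime

from airflow.operators.dummy_operator import DummyOperator
from marquez_airflow import DAG

dag = DAG(
    dag_id='example_marquez_dag',  # hypothetical dag_id
    schedule_interval='@daily',
    start_date=datetime(2021, 1, 1),
    description='Example DAG reporting lineage metadata to Marquez')

# A single no-op task; create_dagrun() and handle_callback() on the DAG
# subclass above take care of reporting job and run metadata around it.
run_this = DummyOperator(task_id='noop', dag=dag)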
class DAG(airflow.models.DAG):
    DEFAULT_NAMESPACE = 'default'
    _job_id_mapping = None
    _marquez_client = None

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.marquez_namespace = os.environ.get('MARQUEZ_NAMESPACE') or \
            DAG.DEFAULT_NAMESPACE
        self.marquez_location = kwargs['default_args'].get(
            'marquez_location', 'unknown')
        self.marquez_input_urns = kwargs['default_args'].get(
            'marquez_input_urns', [])
        self.marquez_output_urns = kwargs['default_args'].get(
            'marquez_output_urns', [])
        self._job_id_mapping = JobIdMapping()

    def create_dagrun(self, *args, **kwargs):
        run_args = "{}"  # TODO extract the run args from the tasks
        marquez_jobrun_id = None
        try:
            marquez_jobrun_id = self.report_jobrun(run_args,
                                                   kwargs['execution_date'])
            log.info('Successfully recorded job run.',
                     airflow_dag_id=self.dag_id,
                     marquez_run_id=marquez_jobrun_id,
                     marquez_namespace=self.marquez_namespace)
        except Exception as e:
            log.error(f'Failed to record job run: {e}',
                      airflow_dag_id=self.dag_id,
                      marquez_namespace=self.marquez_namespace)

        run = super(DAG, self).create_dagrun(*args, **kwargs)

        if marquez_jobrun_id:
            try:
                self._job_id_mapping.set(
                    JobIdMapping.make_key(run.dag_id, run.run_id),
                    marquez_jobrun_id)
            except Exception as e:
                log.error(f'Failed job run lookup: {e}',
                          airflow_dag_id=self.dag_id,
                          airflow_run_id=run.run_id,
                          marquez_run_id=marquez_jobrun_id,
                          marquez_namespace=self.marquez_namespace)

        return run

    def handle_callback(self, *args, **kwargs):
        try:
            self.report_jobrun_change(args[0], **kwargs)
        except Exception as e:
            log.error(f'Failed to record job run state change: {e}',
                      dag_id=self.dag_id)

        return super().handle_callback(*args, **kwargs)

    def report_jobrun(self, run_args, execution_date):
        now_ms = self._now_ms()

        job_name = self.dag_id
        start_time = execution_date.format("%Y-%m-%dT%H:%M:%SZ")
        end_time = self.compute_endtime(execution_date)
        if end_time:
            end_time = end_time.strftime("%Y-%m-%dT%H:%M:%SZ")

        marquez_client = self.get_marquez_client()

        marquez_client.create_job(job_name,
                                  self.marquez_location,
                                  self.marquez_input_urns,
                                  self.marquez_output_urns,
                                  description=self.description)
        log.info(f'Successfully recorded job: {job_name}',
                 airflow_dag_id=self.dag_id,
                 marquez_namespace=self.marquez_namespace)

        marquez_jobrun = marquez_client.create_job_run(
            job_name,
            run_args=run_args,
            nominal_start_time=start_time,
            nominal_end_time=end_time)

        marquez_jobrun_id = marquez_jobrun.get('runId')
        if marquez_jobrun_id:
            marquez_client.mark_job_run_as_running(marquez_jobrun_id)
            log.info(f'Successfully recorded job run: {job_name}',
                     airflow_dag_id=self.dag_id,
                     airflow_dag_execution_time=start_time,
                     marquez_run_id=marquez_jobrun_id,
                     marquez_namespace=self.marquez_namespace,
                     duration_ms=(self._now_ms() - now_ms))
        else:
            log.warn(f'Run id not found: {job_name}',
                     airflow_dag_id=self.dag_id,
                     airflow_dag_execution_time=start_time,
                     marquez_run_id=marquez_jobrun_id,
                     marquez_namespace=self.marquez_namespace,
                     duration_ms=(self._now_ms() - now_ms))

        return marquez_jobrun_id

    def compute_endtime(self, execution_date):
        return self.following_schedule(execution_date)

    def report_jobrun_change(self, dagrun, **kwargs):
        session = kwargs.get('session')
        marquez_job_run_id = self._job_id_mapping.pop(
            JobIdMapping.make_key(dagrun.dag_id, dagrun.run_id), session)
        if marquez_job_run_id:
            log.info('Found job run.',
                     airflow_dag_id=dagrun.dag_id,
                     airflow_run_id=dagrun.run_id,
                     marquez_run_id=marquez_job_run_id,
                     marquez_namespace=self.marquez_namespace)

            if kwargs.get('success'):
                self.get_marquez_client().mark_job_run_as_completed(
                    marquez_job_run_id)
            else:
                self.get_marquez_client().mark_job_run_as_failed(
                    marquez_job_run_id)

            state = 'COMPLETED' if kwargs.get('success') else 'FAILED'
            log.info(f'Marked job run as {state}.',
                     airflow_dag_id=dagrun.dag_id,
                     airflow_run_id=dagrun.run_id,
                     marquez_run_id=marquez_job_run_id,
                     marquez_namespace=self.marquez_namespace)

    def get_marquez_client(self):
        if not self._marquez_client:
            self._marquez_client = MarquezClient(
                namespace_name=self.marquez_namespace)
            self._marquez_client.create_namespace(self.marquez_namespace,
                                                  "default_owner")
        return self._marquez_client

    @staticmethod
    def _now_ms():
        return int(round(time.time() * 1000))
class DAG(airflow.models.DAG):
    DEFAULT_NAMESPACE = 'default'
    _job_id_mapping = None
    _marquez_client = None

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._marquez_dataset_cache = {}
        self._marquez_source_cache = {}
        self.marquez_namespace = os.getenv('MARQUEZ_NAMESPACE',
                                           DAG.DEFAULT_NAMESPACE)
        self._job_id_mapping = JobIdMapping()

    def create_dagrun(self, *args, **kwargs):
        # run Airflow's create_dagrun() first
        dagrun = super(DAG, self).create_dagrun(*args, **kwargs)

        create_dag_start_ms = self._now_ms()
        execution_date = kwargs.get('execution_date')
        run_args = {'external_trigger': kwargs.get('external_trigger', False)}

        extractors = {}
        try:
            extractors = get_extractors()
        except Exception as e:
            log.warn(f'Failed to retrieve extractors: {e}',
                     airflow_dag_id=self.dag_id,
                     marquez_namespace=self.marquez_namespace)

        # Marquez metadata collection
        try:
            marquez_client = self.get_marquez_client()

            # Create the Namespace
            marquez_client.create_namespace(self.marquez_namespace,
                                            "default_owner")

            # Register each task in the DAG
            for task_id, task in self.task_dict.items():
                t = self._now_ms()
                try:
                    self.report_task(
                        dagrun.run_id,
                        execution_date,
                        run_args,
                        task,
                        extractors.get(task.__class__.__name__))
                except Exception as e:
                    log.error(f'Failed to record task: {e}',
                              airflow_dag_id=self.dag_id,
                              task_id=task_id,
                              marquez_namespace=self.marquez_namespace,
                              duration_ms=(self._now_ms() - t))

            log.info('Successfully recorded metadata',
                     airflow_dag_id=self.dag_id,
                     marquez_namespace=self.marquez_namespace,
                     duration_ms=(self._now_ms() - create_dag_start_ms))
        except Exception as e:
            log.error(f'Failed to record metadata: {e}',
                      airflow_dag_id=self.dag_id,
                      marquez_namespace=self.marquez_namespace,
                      duration_ms=(self._now_ms() - create_dag_start_ms))

        return dagrun

    def handle_callback(self, *args, **kwargs):
        try:
            dagrun = args[0]
            task_instances = dagrun.get_task_instances()
            for ti in task_instances:
                try:
                    job_name = f'{ti.dag_id}.{ti.task_id}'
                    self.report_jobrun_change(job_name, dagrun.run_id,
                                              **kwargs)
                except Exception as e:
                    log.error(
                        f'Failed to record task run state change: {e}',
                        dag_id=self.dag_id)
        except Exception as e:
            log.error(f'Failed to record dagrun state change: {e}',
                      dag_id=self.dag_id)

        return super().handle_callback(*args, **kwargs)

    def report_task(self,
                    dag_run_id,
                    execution_date,
                    run_args,
                    task,
                    extractor):
        report_job_start_ms = self._now_ms()
        marquez_client = self.get_marquez_client()
        if execution_date:
            start_time = self._to_iso_8601(execution_date)
            end_time = self.compute_endtime(execution_date)
        else:
            start_time = None
            end_time = None

        if end_time:
            end_time = self._to_iso_8601(end_time)

        task_location = None
        try:
            if hasattr(task, 'file_path') and task.file_path:
                task_location = get_location(task.file_path)
            else:
                task_location = get_location(task.dag.fileloc)
        except Exception:
            log.warn('Unable to fetch the location')

        steps_metadata = []
        if extractor:
            try:
                log.info(f'Using extractor {extractor.__name__}',
                         task_type=task.__class__.__name__,
                         airflow_dag_id=self.dag_id,
                         task_id=task.task_id,
                         airflow_run_id=dag_run_id,
                         marquez_namespace=self.marquez_namespace)
                steps_metadata = extractor(task).extract()
            except Exception as e:
                log.error(f'Failed to extract metadata {e}',
                          airflow_dag_id=self.dag_id,
                          task_id=task.task_id,
                          airflow_run_id=dag_run_id,
                          marquez_namespace=self.marquez_namespace)
        else:
            log.warn('Unable to find an extractor.',
                     task_type=task.__class__.__name__,
                     airflow_dag_id=self.dag_id,
                     task_id=task.task_id,
                     airflow_run_id=dag_run_id,
                     marquez_namespace=self.marquez_namespace)

        task_name = f'{self.dag_id}.{task.task_id}'

        # If no extractor found or failed to extract metadata,
        # report the task metadata
        if not steps_metadata:
            steps_metadata = [StepMetadata(task_name)]

        # store all the JobRuns associated with a task
        marquez_jobrun_ids = []
        for step in steps_metadata:
            input_datasets = []
            output_datasets = []

            try:
                input_datasets = self.register_datasets(step.inputs)
            except Exception as e:
                log.error(f'Failed to register inputs: {e}',
                          inputs=str(step.inputs),
                          airflow_dag_id=self.dag_id,
                          task_id=task.task_id,
                          step=step.name,
                          airflow_run_id=dag_run_id,
                          marquez_namespace=self.marquez_namespace)
            try:
                output_datasets = self.register_datasets(step.outputs)
            except Exception as e:
                log.error(f'Failed to register outputs: {e}',
                          outputs=str(step.outputs),
                          airflow_dag_id=self.dag_id,
                          task_id=task.task_id,
                          step=step.name,
                          airflow_run_id=dag_run_id,
                          marquez_namespace=self.marquez_namespace)

            marquez_client.create_job(
                job_name=step.name,
                job_type='BATCH',  # job type
                location=(step.location or task_location),
                input_dataset=input_datasets,
                output_dataset=output_datasets,
                context=step.context,
                description=self.description,
                namespace_name=self.marquez_namespace)
            log.info(f'Successfully recorded job: {step.name}',
                     airflow_dag_id=self.dag_id,
                     marquez_namespace=self.marquez_namespace)

            marquez_jobrun_id = marquez_client.create_job_run(
                step.name,
                run_args=run_args,
                nominal_start_time=start_time,
                nominal_end_time=end_time).get('runId')

            if marquez_jobrun_id:
                marquez_jobrun_ids.append(marquez_jobrun_id)
                marquez_client.mark_job_run_as_started(marquez_jobrun_id)
            else:
                log.error(f'Failed to get run id: {step.name}',
                          airflow_dag_id=self.dag_id,
                          airflow_run_id=dag_run_id,
                          marquez_namespace=self.marquez_namespace)
            log.info(f'Successfully recorded job run: {step.name}',
                     airflow_dag_id=self.dag_id,
                     airflow_dag_execution_time=start_time,
                     marquez_run_id=marquez_jobrun_id,
                     marquez_namespace=self.marquez_namespace,
                     duration_ms=(self._now_ms() - report_job_start_ms))

        # Store the mapping for all the steps associated with a task
        try:
            self._job_id_mapping.set(
                JobIdMapping.make_key(task_name, dag_run_id),
                json.dumps(marquez_jobrun_ids))
        except Exception as e:
            log.error(f'Failed to set id mapping : {e}',
                      airflow_dag_id=self.dag_id,
                      task_id=task.task_id,
                      airflow_run_id=dag_run_id,
                      marquez_run_id=marquez_jobrun_ids,
                      marquez_namespace=self.marquez_namespace)

    def compute_endtime(self, execution_date):
        return self.following_schedule(execution_date)

    def report_jobrun_change(self, job_name, run_id, **kwargs):
        session = kwargs.get('session')
        marquez_job_run_ids = self._job_id_mapping.pop(
            JobIdMapping.make_key(job_name, run_id), session)
        if marquez_job_run_ids:
            log.info('Found job runs.',
                     airflow_dag_id=self.dag_id,
                     airflow_job_id=job_name,
                     airflow_run_id=run_id,
                     marquez_run_ids=marquez_job_run_ids,
                     marquez_namespace=self.marquez_namespace)

            ids = json.loads(marquez_job_run_ids)
            if kwargs.get('success'):
                for marquez_job_run_id in ids:
                    self.get_marquez_client().mark_job_run_as_completed(
                        marquez_job_run_id)
            else:
                for marquez_job_run_id in ids:
                    self.get_marquez_client().mark_job_run_as_failed(
                        marquez_job_run_id)

            state = 'COMPLETED' if kwargs.get('success') else 'FAILED'
            log.info(f'Marked job run(s) as {state}.',
                     airflow_dag_id=self.dag_id,
                     airflow_job_id=job_name,
                     airflow_run_id=run_id,
                     marquez_run_id=marquez_job_run_ids,
                     marquez_namespace=self.marquez_namespace)

    def get_marquez_client(self):
        if not self._marquez_client:
            self._marquez_client = MarquezClient()
        return self._marquez_client

    @staticmethod
    def _now_ms():
        return int(round(time.time() * 1000))

    def register_datasets(self, datasets):
        dataset_names = []
        if not datasets:
            return dataset_names
        client = self.get_marquez_client()
        for dataset in datasets:
            if isinstance(dataset, Dataset):
                _key = str(dataset)
                if _key not in self._marquez_dataset_cache:
                    source_name = self.register_source(dataset.source)
                    if source_name:
                        dataset = client.create_dataset(
                            dataset.name,
                            dataset.type,
                            dataset.name,  # physical_name the same for now
                            source_name,
                            namespace_name=self.marquez_namespace)
                        dataset_name = dataset.get('name')
                        if dataset_name:
                            self._marquez_dataset_cache[_key] = dataset_name
                            dataset_names.append(dataset_name)
                else:
                    dataset_names.append(self._marquez_dataset_cache[_key])
        return dataset_names

    def register_source(self, source):
        if isinstance(source, Source):
            _key = str(source)
            if _key in self._marquez_source_cache:
                return self._marquez_source_cache[_key]
            client = self.get_marquez_client()
            ds = client.create_source(source.name,
                                      source.type,
                                      source.connection_url)
            source_name = ds.get('name')
            self._marquez_source_cache[_key] = source_name
            return source_name

    @staticmethod
    def _to_iso_8601(dt):
        if isinstance(dt, Pendulum):
            return dt.format(_NOMINAL_TIME_FORMAT)
        else:
            return dt.strftime(_NOMINAL_TIME_FORMAT)