def setUp(self):
    """Create the fixtures shared by the tests in this case.

    Sets up four dummy tasks on a throwaway ``my_component`` DAG, a parent
    ``AirflowPipeline`` whose metadata root is the test temp directory, and
    the input/output artifact dicts, execution properties and driver options
    handed to the code under test.
    """
    # The fallback directory is computed unconditionally, exactly as when it
    # was written inline as the default argument of ``os.environ.get``.
    fallback_dir = self.get_temp_dir()
    self._temp_dir = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                                    fallback_dir)

    component_dag = models.DAG(
        dag_id='my_component',
        start_date=datetime.datetime(2019, 1, 1),
    )

    def make_dummy_task(task_id):
        # Every task is attached to the same throwaway component DAG.
        return dummy_operator.DummyOperator(task_id=task_id,
                                            dag=component_dag)

    # Creation order is kept: checkcache, pythonexec, noop_sink, publishexec.
    self.checkcache_op = make_dummy_task('my_component.checkcache')
    self.tfx_python_op = make_dummy_task('my_component.pythonexec')
    self.noop_sink_op = make_dummy_task('my_component.noop_sink')
    self.publishexec_op = make_dummy_task('my_component.publishexec')

    self._logger_config = logging_utils.LoggerConfig()

    pipeline_kwargs = {
        'pipeline_name': 'pipeline_name',
        'start_date': datetime.datetime(2018, 1, 1),
        'schedule_interval': None,
        'pipeline_root': 'pipeline_root',
        'metadata_db_root': self._temp_dir,
        'metadata_connection_config': None,
        'additional_pipeline_args': None,
        'enable_cache': True,
    }
    self.parent_dag = airflow_pipeline.AirflowPipeline(**pipeline_kwargs)

    self.input_dict = {'i': [TfxArtifact('i')]}
    self.output_dict = {'o': [TfxArtifact('o')]}
    self.exec_properties = {'e': 'e'}
    self.driver_options = {'d': 'd'}
def _create_pipeline():
    """Build the example pipeline made of a single custom head component.

    The component is fed one input artifact whose URI points at
    ``input_example.txt`` under ``_data_root``, plus one string and one
    integer execution parameter.

    Returns:
      Whatever ``code2flow.create_pipeline`` returns for that single
      component, with caching enabled and ``_metadata_db_root`` as the
      metadata root.
    """
    # The artifact and the channel wrapping it share the same type name.
    artifact_type = "RandomTypeNameForInput"

    example_artifact = TfxArtifact(type_name=artifact_type)
    example_artifact.uri = os.path.join(_data_root, "input_example.txt")

    example_channel = channel.Channel(
        artifacts=[example_artifact],
        type_name=artifact_type,
    )

    head_component = custom_component.CustomHeadComponent(
        input_example=example_channel,
        string_execution_parameter="My awesome string",
        integer_execution_parameter=42,
    )

    return code2flow.create_pipeline(
        components=[head_component],
        enable_cache=True,
        metadata_db_root=_metadata_db_root,
    )
def search_artifacts(self, artifact_name: Text, pipeline_name: Text,
                     run_id: Text,
                     producer_component_id: Text) -> List[TfxArtifact]:
    """Search artifacts that match the given producer info.

    Args:
      artifact_name: the name of the artifact as set by the producer
        component. The name is logged both in artifacts and in the events
        when the execution is published.
      pipeline_name: the name of the pipeline that produces the artifact.
      run_id: the run id of the pipeline run that produces the artifact.
      producer_component_id: the id of the component that produces the
        artifact.

    Returns:
      A list of TfxArtifacts that match the given info. The list is built
      from a set of artifact ids, so its order is whatever the store returns
      for that id list.

    Raises:
      RuntimeError: when no matching execution is found given producer info.
    """
    producer_execution = None
    # Every execution is scanned without breaking out early, so when several
    # executions match, the last one yielded by the store is the one used.
    for execution in self._store.get_executions():
        properties = execution.properties
        if (properties['pipeline_name'].string_value == pipeline_name and
                properties['run_id'].string_value == run_id and
                properties['component_id'].string_value ==
                producer_component_id):
            producer_execution = execution

    if not producer_execution:
        # The two literals are concatenated, so the separating space must be
        # part of the first one.
        raise RuntimeError(
            'Cannot find matching execution with pipeline name {}, '
            'run id {} and component id {}'.format(pipeline_name, run_id,
                                                   producer_component_id))

    # Keep only OUTPUT events whose first path step is keyed by the artifact
    # name. An event with an empty path cannot name any artifact, so it is
    # skipped instead of raising IndexError on steps[0].
    matching_artifact_ids = {
        event.artifact_id
        for event in self._store.get_events_by_execution_ids(
            [producer_execution.id])
        if (event.type == metadata_store_pb2.Event.OUTPUT and
            event.path.steps and
            event.path.steps[0].key == artifact_name)
    }

    result_artifacts = []
    for stored_artifact in self._store.get_artifacts_by_id(
            list(matching_artifact_ids)):
        # Wrap each stored artifact in a TfxArtifact of the recorded type.
        tfx_artifact = TfxArtifact(
            stored_artifact.properties['type_name'].string_value)
        tfx_artifact.artifact = stored_artifact
        result_artifacts.append(tfx_artifact)
    return result_artifacts
def _update_input_dict_from_xcom(self, task_instance):
    """Determine actual input artifacts used from upstream tasks.

    For every key in ``self._input_dict`` that has an entry in
    ``self._input_source_dict``, pulls the JSON payload published under that
    source via XCom and overwrites the entries of the input list, position by
    position, with the artifacts parsed from the payload.

    Keys without a source are left untouched. A source whose XCom value is
    missing or empty is logged as a warning and skipped.

    Args:
      task_instance: object exposing ``xcom_pull(key=..., dag_id=...)``.
    """
    for key, input_list in self._input_dict.items():
        source = self._input_source_dict.get(key)
        if not source:
            continue

        xcom_result = task_instance.xcom_pull(
            key=source.key, dag_id=source.component_id)
        if not xcom_result:
            # ``Logger.warn`` is a deprecated alias that no longer exists on
            # Python 3.13+; ``warning`` is the supported spelling.
            self._logger.warning('Cannot resolve xcom source from %s', source)
            continue

        # NOTE(review): entries are replaced by index, so this presumably
        # expects the resolved list to be no longer than input_list; a longer
        # payload would raise IndexError on a plain list. Confirm with the
        # producer side.
        for index, resolved_json_dict in enumerate(json.loads(xcom_result)):
            input_list[index] = TfxArtifact.parse_from_json_dict(
                resolved_json_dict)
def test_build_graph(self):
    r"""Tests building the airflow DAG graph using add_node_to_graph().

    The dependency graph is as below:

                 component_one
                 /           \
                /             \
          component_two   component_three
                \             /
                 \           /
                 component_four

    Each node is registered with the artifact lists it consumes and
    produces; the test then checks that every operator's upstream list
    matches the diagram.
    """
    component_one = dummy_operator.DummyOperator(
        task_id='one', dag=self.pipeline)
    component_two = dummy_operator.DummyOperator(
        task_id='two', dag=self.pipeline)
    component_three = dummy_operator.DummyOperator(
        task_id='three', dag=self.pipeline)
    component_four = dummy_operator.DummyOperator(
        task_id='four', dag=self.pipeline)

    component_one_input_a = TfxArtifact('i1a')
    component_one_input_b = TfxArtifact('i1b')
    component_one_output_a = TfxArtifact('o1a')
    component_one_output_b = TfxArtifact('o1b')
    component_two_output = TfxArtifact('o2')
    component_three_output = TfxArtifact('o3')
    component_four_output = TfxArtifact('o4')

    component_one_input_dict = {
        'i1a': [component_one_input_a],
        'i1b': [component_one_input_b]
    }
    component_one_output_dict = {
        'o1a': [component_one_output_a],
        'o1b': [component_one_output_b]
    }
    component_two_input_dict = {
        'i2a': [component_one_output_a],
        'i2b': [component_one_output_b]
    }
    component_two_output_dict = {'o2': [component_two_output]}
    component_three_input_dict = {
        'i3a': [component_one_output_a],
        'i3b': [component_one_output_b]
    }
    # Component three must produce its own artifact: component four consumes
    # component_three_output, so listing component_two_output here would
    # leave that artifact without any producer.
    component_three_output_dict = {'o3': [component_three_output]}
    component_four_input_dict = {
        'i4a': [component_two_output],
        'i4b': [component_three_output]
    }
    component_four_output_dict = {'o4': [component_four_output]}

    # Nodes are added in dependency order: one, two, three, four.
    self.pipeline.add_node_to_graph(
        component_one,
        consumes=component_one_input_dict.values(),
        produces=component_one_output_dict.values())
    self.pipeline.add_node_to_graph(
        component_two,
        consumes=component_two_input_dict.values(),
        produces=component_two_output_dict.values())
    self.pipeline.add_node_to_graph(
        component_three,
        consumes=component_three_input_dict.values(),
        produces=component_three_output_dict.values())
    self.pipeline.add_node_to_graph(
        component_four,
        consumes=component_four_input_dict.values(),
        produces=component_four_output_dict.values())

    # NOTE(review): assertItemsEqual is the Python 2 unittest name;
    # presumably the test base class provides it here - confirm.
    self.assertItemsEqual(component_one.upstream_list, [])
    self.assertItemsEqual(component_two.upstream_list, [component_one])
    self.assertItemsEqual(component_three.upstream_list, [component_one])
    self.assertItemsEqual(component_four.upstream_list,
                          [component_two, component_three])