예제 #1
0
 def setUp(self):
     """Build the fixtures shared by the tests of this case.

     Sets up, in order: the temp directory, a standalone DAG with four
     DummyOperator tasks (checkcache, pythonexec, noop_sink, publishexec),
     a default logger config, the parent AirflowPipeline, and single-entry
     input / output / exec-property / driver-option dictionaries.
     """
     # The fallback is computed unconditionally, exactly as when it is passed
     # inline as the default argument of os.environ.get().
     fallback_dir = self.get_temp_dir()
     self._temp_dir = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                                     fallback_dir)

     component_dag = models.DAG(
         dag_id='my_component', start_date=datetime.datetime(2019, 1, 1))

     def new_task(task_id):
         # Every task of the fake component lives in the same DAG.
         return dummy_operator.DummyOperator(task_id=task_id,
                                             dag=component_dag)

     self.checkcache_op = new_task('my_component.checkcache')
     self.tfx_python_op = new_task('my_component.pythonexec')
     self.noop_sink_op = new_task('my_component.noop_sink')
     self.publishexec_op = new_task('my_component.publishexec')

     self._logger_config = logging_utils.LoggerConfig()

     pipeline_settings = {
         'pipeline_name': 'pipeline_name',
         'start_date': datetime.datetime(2018, 1, 1),
         'schedule_interval': None,
         'pipeline_root': 'pipeline_root',
         'metadata_db_root': self._temp_dir,
         'metadata_connection_config': None,
         'additional_pipeline_args': None,
         'enable_cache': True,
     }
     self.parent_dag = airflow_pipeline.AirflowPipeline(**pipeline_settings)

     self.input_dict = {'i': [TfxArtifact('i')]}
     self.output_dict = {'o': [TfxArtifact('o')]}
     self.exec_properties = {'e': 'e'}
     self.driver_options = {'d': 'd'}
def _create_pipeline():
    """Assemble the example pipeline made of one custom head component.

    The component is fed a single input artifact whose URI points at
    ``input_example.txt`` under ``_data_root``.

    Returns:
        The result of ``code2flow.create_pipeline`` for that one component,
        with caching enabled and ``_metadata_db_root`` as the metadata root.
    """
    # The artifact and the channel wrapping it share the same type name.
    artifact_type = "RandomTypeNameForInput"

    example_artifact = TfxArtifact(type_name=artifact_type)
    example_artifact.uri = os.path.join(_data_root, "input_example.txt")

    example_channel = channel.Channel(
        artifacts=[example_artifact], type_name=artifact_type)

    head_component = custom_component.CustomHeadComponent(
        input_example=example_channel,
        string_execution_parameter="My awesome string",
        integer_execution_parameter=42)

    return code2flow.create_pipeline(
        components=[head_component],
        enable_cache=True,
        metadata_db_root=_metadata_db_root)
예제 #3
0
    def search_artifacts(self, artifact_name: Text, pipeline_name: Text,
                         run_id: Text,
                         producer_component_id: Text) -> List[TfxArtifact]:
        """Search artifacts that match the given producer info.

        Args:
            artifact_name: the name of the artifact as set by the producer
                component. It is compared against the key of the first path
                step of each OUTPUT event of the producer execution.
            pipeline_name: the name of the pipeline that produced the
                artifact.
            run_id: the run id of the pipeline run that produced the artifact.
            producer_component_id: the id of the component that produced the
                artifact.

        Returns:
            A list of TfxArtifacts that match the given info; empty when the
            producer execution has no matching OUTPUT event.

        Raises:
            RuntimeError: when no matching execution is found for the given
                producer info.
        """
        producer_execution = None
        for execution in self._store.get_executions():
            properties = execution.properties
            if (properties['pipeline_name'].string_value == pipeline_name
                    and properties['run_id'].string_value == run_id
                    and properties['component_id'].string_value
                    == producer_component_id):
                # Deliberately no `break`: when several executions match, the
                # last one returned by the store is used.
                producer_execution = execution
        if not producer_execution:
            raise RuntimeError(
                'Cannot find matching execution with pipeline name {}, '
                'run id {} and component id {}'.format(pipeline_name, run_id,
                                                       producer_component_id))

        matching_artifact_ids = set()
        for event in self._store.get_events_by_execution_ids(
            [producer_execution.id]):
            # An OUTPUT event without any path step carries no artifact name,
            # so it can never match; skip it instead of failing on steps[0].
            if (event.type == metadata_store_pb2.Event.OUTPUT
                    and event.path.steps
                    and event.path.steps[0].key == artifact_name):
                matching_artifact_ids.add(event.artifact_id)

        result_artifacts = []
        for a in self._store.get_artifacts_by_id(list(matching_artifact_ids)):
            # Wrap the stored artifact, keyed by its recorded type name.
            tfx_artifact = TfxArtifact(a.properties['type_name'].string_value)
            tfx_artifact.artifact = a
            result_artifacts.append(tfx_artifact)
        return result_artifacts
예제 #4
0
 def _update_input_dict_from_xcom(self, task_instance):
   """Determine actual input artifacts used from upstream tasks.

   For every key of ``self._input_dict`` that has a source registered in
   ``self._input_source_dict``, pulls the XCom value identified by the
   source's ``key`` and ``component_id``, decodes it as JSON and overwrites
   the entries of the input list in place, position by position. Keys with
   no source, or whose XCom value is empty, are left untouched.

   Args:
     task_instance: object exposing ``xcom_pull(key=..., dag_id=...)``;
       presumably the Airflow task instance of the running task.
   """
   for key, input_list in self._input_dict.items():
     source = self._input_source_dict.get(key)
     if not source:
       continue
     xcom_result = task_instance.xcom_pull(
         key=source.key, dag_id=source.component_id)
     if not xcom_result:
       # `warning` replaces the deprecated `warn` alias of stdlib loggers.
       self._logger.warning('Cannot resolve xcom source from %s', source)
       continue
     resolved_list = json.loads(xcom_result)
     # NOTE(review): entries are assigned by index, so this assumes the input
     # list already has at least as many slots as the resolved list.
     for index, resolved_json_dict in enumerate(resolved_list):
       input_list[index] = TfxArtifact.parse_from_json_dict(resolved_json_dict)
예제 #5
0
  def test_build_graph(self):
    r"""Tests building airflow DAG graph using add_node_to_graph().

    The dependency graph under test is a diamond:
                     component_one
                     /           \
                    /             \
          component_two         component_three
                    \             /
                     \           /
                     component_four

    Dependencies are expressed only through the artifacts each node consumes
    and produces; the test then checks the upstream list of every operator.
    """

    component_one = dummy_operator.DummyOperator(
        task_id='one', dag=self.pipeline)
    component_two = dummy_operator.DummyOperator(
        task_id='two', dag=self.pipeline)
    component_three = dummy_operator.DummyOperator(
        task_id='three', dag=self.pipeline)
    component_four = dummy_operator.DummyOperator(
        task_id='four', dag=self.pipeline)

    component_one_input_a = TfxArtifact('i1a')
    component_one_input_b = TfxArtifact('i1b')
    component_one_output_a = TfxArtifact('o1a')
    component_one_output_b = TfxArtifact('o1b')
    component_two_output = TfxArtifact('o2')
    component_three_output = TfxArtifact('o3')
    component_four_output = TfxArtifact('o4')

    component_one_input_dict = {
        'i1a': [component_one_input_a],
        'i1b': [component_one_input_b]
    }
    component_one_output_dict = {
        'o1a': [component_one_output_a],
        'o1b': [component_one_output_b]
    }
    # Components two and three both consume the outputs of component one.
    component_two_input_dict = {
        'i2a': [component_one_output_a],
        'i2b': [component_one_output_b]
    }
    component_two_output_dict = {'o2': [component_two_output]}
    component_three_input_dict = {
        'i3a': [component_one_output_a],
        'i3b': [component_one_output_b]
    }
    # Component three publishes its own artifact, so that component four
    # depends on both branches of the diamond.
    component_three_output_dict = {'o3': [component_three_output]}
    component_four_input_dict = {
        'i4a': [component_two_output],
        'i4b': [component_three_output]
    }
    component_four_output_dict = {'o4': [component_four_output]}

    # Nodes are registered in topological order: producers before consumers.
    self.pipeline.add_node_to_graph(
        component_one,
        consumes=component_one_input_dict.values(),
        produces=component_one_output_dict.values())
    self.pipeline.add_node_to_graph(
        component_two,
        consumes=component_two_input_dict.values(),
        produces=component_two_output_dict.values())
    self.pipeline.add_node_to_graph(
        component_three,
        consumes=component_three_input_dict.values(),
        produces=component_three_output_dict.values())
    self.pipeline.add_node_to_graph(
        component_four,
        consumes=component_four_input_dict.values(),
        produces=component_four_output_dict.values())

    self.assertItemsEqual(component_one.upstream_list, [])
    self.assertItemsEqual(component_two.upstream_list, [component_one])
    self.assertItemsEqual(component_three.upstream_list, [component_one])
    self.assertItemsEqual(component_four.upstream_list,
                          [component_two, component_three])