Example #1
0
 def setUp(self):
     """Prepare dummy operators, a parent pipeline and argument dicts.

     Creates five DummyOperator tasks on a throwaway DAG with id
     'my_component', an AirflowPipeline whose metadata_db_root is the
     temp directory, and small input/output/exec-property/driver-option
     dictionaries, all stored on the test instance.
     """
     # Use TEST_UNDECLARED_OUTPUTS_DIR when it is set in the environment;
     # otherwise fall back to the test case's own temp directory.
     fallback_dir = self.get_temp_dir()
     self._temp_dir = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                                     fallback_dir)

     component_dag = models.DAG(
         dag_id='my_component', start_date=datetime.datetime(2019, 1, 1))

     def make_op(task_id):
         # Every operator is attached to the same dummy component DAG.
         return dummy_operator.DummyOperator(task_id=task_id,
                                             dag=component_dag)

     self.checkcache_op = make_op('my_component.checkcache')
     self.tfx_python_op = make_op('my_component.pythonexec')
     self.tfx_docker_op = make_op('my_component.dockerexec')
     self.publishcache_op = make_op('my_component.publishcache')
     self.publishexec_op = make_op('my_component.publishexec')

     self._logger_config = logging_utils.LoggerConfig()

     pipeline_kwargs = {
         'pipeline_name': 'pipeline_name',
         'start_date': datetime.datetime(2018, 1, 1),
         'schedule_interval': None,
         'pipeline_root': 'pipeline_root',
         'metadata_db_root': self._temp_dir,
         'metadata_connection_config': None,
         'additional_pipeline_args': None,
         'docker_operator_cfg': None,
         'enable_cache': True,
     }
     self.parent_dag = airflow_pipeline.AirflowPipeline(**pipeline_kwargs)

     input_artifact = TfxType('i')
     self.input_dict = {'i': [input_artifact]}
     output_artifact = TfxType('o')
     self.output_dict = {'o': [output_artifact]}
     self.exec_properties = {'e': 'e'}
     self.driver_options = {'d': 'd'}
Example #2
0
 def setUp(self):
     """Prepare dummy operators, a parent pipeline and argument dicts.

     Creates five DummyOperator tasks on a throwaway DAG with id
     'my_component', an AirflowPipeline configured with fixed string
     roots (including log_root), and small input/output/exec-property/
     driver-option dictionaries, all stored on the test instance.
     """
     component_dag = models.DAG(
         dag_id='my_component', start_date=datetime.datetime(2019, 1, 1))

     def make_op(task_id):
         # Every operator is attached to the same dummy component DAG.
         return dummy_operator.DummyOperator(task_id=task_id,
                                             dag=component_dag)

     self.checkcache_op = make_op('my_component.checkcache')
     self.tfx_python_op = make_op('my_component.pythonexec')
     self.tfx_docker_op = make_op('my_component.dockerexec')
     self.publishcache_op = make_op('my_component.publishcache')
     self.publishexec_op = make_op('my_component.publishexec')

     pipeline_kwargs = {
         'pipeline_name': 'pipeline_name',
         'start_date': datetime.datetime(2018, 1, 1),
         'schedule_interval': None,
         'pipeline_root': 'pipeline_root',
         'metadata_db_root': 'metadata_db_root',
         'metadata_connection_config': None,
         'additional_pipeline_args': None,
         'docker_operator_cfg': None,
         'enable_cache': True,
         'log_root': 'log_root',
     }
     self.parent_dag = airflow_pipeline.AirflowPipeline(**pipeline_kwargs)

     input_artifact = TfxType('i')
     self.input_dict = {'i': [input_artifact]}
     output_artifact = TfxType('o')
     self.output_dict = {'o': [output_artifact]}
     self.exec_properties = {'e': 'e'}
     self.driver_options = {'d': 'd'}
Example #3
0
 def _update_input_dict_from_xcom(self, task_instance):
   """Determine actual input artifacts used from upstream tasks.

   For each key of self._input_dict that has an entry in
   self._input_source_dict, pulls the xcom value identified by the source's
   `key` and `component_id`, decodes it as a JSON list, and overwrites the
   entries of the corresponding input list in place, index by index.
   Keys without a source, or whose xcom value is empty, are left untouched.

   Args:
     task_instance: Object exposing xcom_pull(key=..., dag_id=...);
       presumably the Airflow task instance of the running task.
   """
   for key, input_list in self._input_dict.items():
     source = self._input_source_dict.get(key)
     if not source:
       # No upstream source registered for this input; nothing to resolve.
       continue

     xcom_result = task_instance.xcom_pull(
         key=source.key, dag_id=source.component_id)
     if not xcom_result:
       # `warning` replaces the deprecated `warn` alias.
       self._logger.warning('Cannot resolve xcom source from %s', source)
       continue

     # NOTE(review): assumes the resolved list is no longer than input_list;
     # a longer list would raise IndexError on assignment below.
     for index, resolved_json_dict in enumerate(json.loads(xcom_result)):
       input_list[index] = TfxType.parse_from_json_dict(resolved_json_dict)
Example #4
0
  def test_build_graph(self):
    r"""Tests building airflow DAG graph using add_node_to_graph().

    The dependency graph built by this test is:
                     component_one
                     /           \
                    /             \
          component_two         component_three
                    \             /
                     \           /
                     component_four

    Dependencies are expressed only through shared artifacts: a node that
    consumes an artifact is expected to end up downstream of the node that
    produces it.
    """

    component_one = dummy_operator.DummyOperator(
        task_id='one', dag=self.pipeline)
    component_two = dummy_operator.DummyOperator(
        task_id='two', dag=self.pipeline)
    component_three = dummy_operator.DummyOperator(
        task_id='three', dag=self.pipeline)
    component_four = dummy_operator.DummyOperator(
        task_id='four', dag=self.pipeline)

    component_one_input_a = TfxType('i1a')
    component_one_input_b = TfxType('i1b')
    component_one_output_a = TfxType('o1a')
    component_one_output_b = TfxType('o1b')
    component_two_output = TfxType('o2')
    component_three_output = TfxType('o3')
    component_four_output = TfxType('o4')

    component_one_input_dict = {
        'i1a': [component_one_input_a],
        'i1b': [component_one_input_b]
    }
    component_one_output_dict = {
        'o1a': [component_one_output_a],
        'o1b': [component_one_output_b]
    }
    # Components two and three both consume the outputs of component one.
    component_two_input_dict = {
        'i2a': [component_one_output_a],
        'i2b': [component_one_output_b]
    }
    component_two_output_dict = {'o2': [component_two_output]}
    component_three_input_dict = {
        'i3a': [component_one_output_a],
        'i3b': [component_one_output_b]
    }
    # Component three must produce its own artifact (not component two's),
    # otherwise the 'i4b' input of component four has no producer.
    component_three_output_dict = {'o3': [component_three_output]}
    # Component four joins the two branches.
    component_four_input_dict = {
        'i4a': [component_two_output],
        'i4b': [component_three_output]
    }
    component_four_output_dict = {'o4': [component_four_output]}

    self.pipeline.add_node_to_graph(
        component_one,
        consumes=component_one_input_dict.values(),
        produces=component_one_output_dict.values())
    self.pipeline.add_node_to_graph(
        component_two,
        consumes=component_two_input_dict.values(),
        produces=component_two_output_dict.values())
    self.pipeline.add_node_to_graph(
        component_three,
        consumes=component_three_input_dict.values(),
        produces=component_three_output_dict.values())
    self.pipeline.add_node_to_graph(
        component_four,
        consumes=component_four_input_dict.values(),
        produces=component_four_output_dict.values())

    self.assertItemsEqual(component_one.upstream_list, [])
    self.assertItemsEqual(component_two.upstream_list, [component_one])
    self.assertItemsEqual(component_three.upstream_list, [component_one])
    self.assertItemsEqual(component_four.upstream_list,
                          [component_two, component_three])