def setUp(self):
  """Builds a dummy component task group and a parent AirflowPipeline fixture."""
  # Prefer the Bazel-declared undeclared-outputs dir when running under Bazel;
  # otherwise fall back to the test's own temp dir.
  self._temp_dir = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR',
                                  self.get_temp_dir())
  component_dag = models.DAG(
      dag_id='my_component', start_date=datetime.datetime(2019, 1, 1))
  # One dummy operator per stage of a component's Airflow task sequence.
  for attr_name, task_id in (
      ('checkcache_op', 'my_component.checkcache'),
      ('tfx_python_op', 'my_component.pythonexec'),
      ('tfx_docker_op', 'my_component.dockerexec'),
      ('publishcache_op', 'my_component.publishcache'),
      ('publishexec_op', 'my_component.publishexec'),
  ):
    setattr(
        self, attr_name,
        dummy_operator.DummyOperator(task_id=task_id, dag=component_dag))
  self._logger_config = logging_utils.LoggerConfig()
  # Parent pipeline DAG the component tasks are adopted into by the tests.
  self.parent_dag = airflow_pipeline.AirflowPipeline(
      pipeline_name='pipeline_name',
      start_date=datetime.datetime(2018, 1, 1),
      schedule_interval=None,
      pipeline_root='pipeline_root',
      metadata_db_root=self._temp_dir,
      metadata_connection_config=None,
      additional_pipeline_args=None,
      docker_operator_cfg=None,
      enable_cache=True)
  # Minimal single-entry dicts exercising each adapter argument.
  self.input_dict = {'i': [TfxType('i')]}
  self.output_dict = {'o': [TfxType('o')]}
  self.exec_properties = {'e': 'e'}
  self.driver_options = {'d': 'd'}
def setUp(self):
  """Creates dummy per-stage operators and a parent AirflowPipeline fixture."""
  component_dag = models.DAG(
      dag_id='my_component', start_date=datetime.datetime(2019, 1, 1))

  def _make_op(stage):
    # All dummy stage tasks share the same single-component DAG.
    return dummy_operator.DummyOperator(
        task_id='my_component.' + stage, dag=component_dag)

  self.checkcache_op = _make_op('checkcache')
  self.tfx_python_op = _make_op('pythonexec')
  self.tfx_docker_op = _make_op('dockerexec')
  self.publishcache_op = _make_op('publishcache')
  self.publishexec_op = _make_op('publishexec')
  # Parent pipeline DAG the component tasks are adopted into by the tests.
  self.parent_dag = airflow_pipeline.AirflowPipeline(
      pipeline_name='pipeline_name',
      start_date=datetime.datetime(2018, 1, 1),
      schedule_interval=None,
      pipeline_root='pipeline_root',
      metadata_db_root='metadata_db_root',
      metadata_connection_config=None,
      additional_pipeline_args=None,
      docker_operator_cfg=None,
      enable_cache=True,
      log_root='log_root')
  # Minimal single-entry dicts exercising each adapter argument.
  self.input_dict = {'i': [TfxType('i')]}
  self.output_dict = {'o': [TfxType('o')]}
  self.exec_properties = {'e': 'e'}
  self.driver_options = {'d': 'd'}
def _update_input_dict_from_xcom(self, task_instance):
  """Determine actual input artifacts used from upstream tasks."""
  for input_name, artifact_list in self._input_dict.items():
    source = self._input_source_dict.get(input_name)
    if not source:
      # No upstream producer registered for this input; keep it as-is.
      continue
    xcom_result = task_instance.xcom_pull(
        key=source.key, dag_id=source.component_id)
    if not xcom_result:
      # Best effort: log and leave the declared artifacts unchanged.
      self._logger.warn('Cannot resolve xcom source from %s', source)
      continue
    # Overwrite the declared artifacts in place with the resolved ones
    # published by the upstream task (JSON-serialized list of artifacts).
    for idx, artifact_json in enumerate(json.loads(xcom_result)):
      artifact_list[idx] = TfxType.parse_from_json_dict(artifact_json)
def test_build_graph(self):
  r"""Tests building airflow DAG graph using add_node_to_graph().

  The dependency graph beside is as below:
                    component_one
                    /           \
                   /             \
        component_two      component_three
                   \             /
                    \           /
                    component_four
  """
  component_one = dummy_operator.DummyOperator(
      task_id='one', dag=self.pipeline)
  component_two = dummy_operator.DummyOperator(
      task_id='two', dag=self.pipeline)
  component_three = dummy_operator.DummyOperator(
      task_id='three', dag=self.pipeline)
  component_four = dummy_operator.DummyOperator(
      task_id='four', dag=self.pipeline)

  component_one_input_a = TfxType('i1a')
  component_one_input_b = TfxType('i1b')
  component_one_output_a = TfxType('o1a')
  component_one_output_b = TfxType('o1b')
  component_two_output = TfxType('o2')
  component_three_output = TfxType('o3')
  component_four_output = TfxType('o4')

  component_one_input_dict = {
      'i1a': [component_one_input_a],
      'i1b': [component_one_input_b]
  }
  component_one_output_dict = {
      'o1a': [component_one_output_a],
      'o1b': [component_one_output_b]
  }
  component_two_input_dict = {
      'i2a': [component_one_output_a],
      'i2b': [component_one_output_b]
  }
  component_two_output_dict = {'o2': [component_two_output]}
  component_three_input_dict = {
      'i3a': [component_one_output_a],
      'i3b': [component_one_output_b]
  }
  # BUGFIX: component_three must produce its own artifact, not
  # component_two's; the old value ([component_two_output]) made
  # component_four's edge to component_three depend on a mis-wired
  # producer rather than on component_three_output as diagrammed.
  component_three_output_dict = {'o3': [component_three_output]}
  component_four_input_dict = {
      'i4a': [component_two_output],
      'i4b': [component_three_output]
  }
  component_four_output_dict = {'o4': [component_four_output]}

  self.pipeline.add_node_to_graph(
      component_one,
      consumes=component_one_input_dict.values(),
      produces=component_one_output_dict.values())
  self.pipeline.add_node_to_graph(
      component_two,
      consumes=component_two_input_dict.values(),
      produces=component_two_output_dict.values())
  self.pipeline.add_node_to_graph(
      component_three,
      consumes=component_three_input_dict.values(),
      produces=component_three_output_dict.values())
  self.pipeline.add_node_to_graph(
      component_four,
      consumes=component_four_input_dict.values(),
      produces=component_four_output_dict.values())

  # Verify upstream edges match the diagram above.
  self.assertItemsEqual(component_one.upstream_list, [])
  self.assertItemsEqual(component_two.upstream_list, [component_one])
  self.assertItemsEqual(component_three.upstream_list, [component_one])
  self.assertItemsEqual(component_four.upstream_list,
                        [component_two, component_three])