def test_add_to_dag_fragment(self):
    with DAG(dag_id='d1', default_args=DEFAULT_DAG_ARGS) as dag:
        op1 = DummyOperator(task_id='d1t1')
        op2 = DummyOperator(task_id='d1t2')
        op3 = DummyOperator(task_id='d1t3')
        op4 = DummyOperator(task_id='d1t4')
        op5 = DummyOperator(task_id='d1t5')
        op6 = DummyOperator(task_id='d1t6')
        op1 >> [op2, op3] >> op4
        op2 >> [op5, op6]

        down_op1 = DummyOperator(task_id='d2t1')
        down_op2 = DummyOperator(task_id='d2t2')
        down_op3 = DummyOperator(task_id='d2t3')
        down_op4 = DummyOperator(task_id='d2t4')
        down_op5 = DummyOperator(task_id='d2t5')
        [down_op1, down_op2] >> down_op3 >> [down_op4, down_op5]

        frag_up: DAGFragment = DAGFragment([op1])
        frag_down: DAGFragment = DAGFragment([down_op1, down_op2])

        TransformerUtils.add_downstream_dag_fragment(frag_up, frag_down)

        self.assertTrue(down_op1 in op4.downstream_list)
        self.assertTrue(down_op2 in op4.downstream_list)
        self.assertTrue(down_op1 in op5.downstream_list)
        self.assertTrue(down_op2 in op5.downstream_list)
        self.assertTrue(down_op1 in op6.downstream_list)
        self.assertTrue(down_op2 in op6.downstream_list)
def test_find_sub_dag(self):
    with DAG(dag_id='d1', default_args=DEFAULT_DAG_ARGS) as dag:
        def fn0():
            print("hi")

        def fn1():
            math.pow(1, 2)

        def fn2():
            math.factorial(1)

        op1 = DummyOperator(task_id='op1')
        op2 = DummyOperator(task_id='op2')
        op3 = PythonOperator(task_id='op3', python_callable=fn0)
        op4 = DummyOperator(task_id='op4')
        op1 >> [op2, op3] >> op4

        op5 = PythonOperator(task_id='op5', python_callable=fn1)
        op6 = PythonOperator(task_id='op6', python_callable=fn0)
        op7 = DummyOperator(task_id='op7')
        op8 = PythonOperator(task_id='op8', python_callable=fn2)
        op2 >> op5 >> [op6, op7] >> op8

        op1_matcher = ClassTaskMatcher(DummyOperator)
        op2_matcher = ClassTaskMatcher(DummyOperator)
        op3_matcher = PythonCallTaskMatcher(print)
        op4_matcher = ClassTaskMatcher(DummyOperator)
        op1_matcher >> [op2_matcher, op3_matcher] >> op4_matcher

        op5_matcher = PythonCallTaskMatcher(math.pow)
        op6_matcher = ClassTaskMatcher(PythonOperator)
        op7_matcher = ClassTaskMatcher(DummyOperator)
        op8_matcher = PythonCallTaskMatcher(math.factorial)
        op5_matcher >> [op6_matcher, op7_matcher] >> op8_matcher

        expected_sub_dag = nx.DiGraph([(op5, op6), (op5, op7),
                                       (op6, op8), (op7, op8)])
        dag_dg, found_sub_dags = TransformerUtils.find_sub_dag(
            dag, [op5_matcher])
        self.assertEqual(len(found_sub_dags), 1)
        diff_dg: nx.DiGraph = nx.symmetric_difference(expected_sub_dag,
                                                      found_sub_dags[0])
        self.assertEqual(len(diff_dg.edges), 0)

        expected_sub_dag = nx.DiGraph([(op1, op2), (op1, op3),
                                       (op2, op4), (op3, op4)])
        dag_dg, found_sub_dags = TransformerUtils.find_sub_dag(
            dag, [op1_matcher])
        self.assertEqual(len(found_sub_dags), 1)
        diff_dg: nx.DiGraph = nx.symmetric_difference(expected_sub_dag,
                                                      found_sub_dags[0])
        self.assertEqual(len(diff_dg.edges), 0)
def transform(self, src_operator: BaseOperator, parent_fragment: DAGFragment,
              upstream_fragments: List[DAGFragment]) -> DAGFragment:
    TestTransformer3.livy_batch_op = TransformerUtils.find_op_in_fragment_list_strict(
        upstream_fragments, operator_type=LivyBatchOperator)
    return super(TestTransformer3, self).transform(src_operator, parent_fragment,
                                                   upstream_fragments)
def transform(self, subdag: nx.DiGraph, parent_fragment: DAGFragment) -> DAGFragment:
    transformer = EmrCreateJobFlowOperatorTransformer(self.dag, self.defaults)
    return transformer.transform(
        TransformerUtils.find_matching_tasks(
            subdag, ClassTaskMatcher(EmrCreateJobFlowOperator))[0],
        parent_fragment)
def transform(self, src_operator: BaseOperator, parent_fragment: DAGFragment,
              upstream_fragments: List[DAGFragment]) -> DAGFragment:
    TestTransformer4.livy_sensor_op = TransformerUtils.find_op_in_parent_fragment_chain(
        parent_fragment, operator_type=LivyBatchSensor)
    return super(TestTransformer4, self).transform(src_operator, parent_fragment,
                                                   upstream_fragments)
def transform(self, src_operator: BaseOperator, parent_fragment: DAGFragment,
              upstream_fragments: List[DAGFragment]) -> DAGFragment:
    """
    This transformer assumes and relies on the fact that an upstream transformation of an
    :class:`~airflow.contrib.operators.emr_create_job_flow_operator.EmrCreateJobFlowOperator`
    has already taken place, since it needs to find the output of that transformation to get
    the `cluster_name` and `azure_conn_id` from that operator (which should have been an
    :class:`~airflowhdi.operators.AzureHDInsightCreateClusterOperator`)

    Creates an :class:`~airflowhdi.sensors.AzureHDInsightClusterSensor` in non-provisioning
    mode to monitor the cluster till it reaches a terminal state (cluster shut down by the
    user or failed).

    .. warning::

        We do not have a way to tell the HDInsight cluster to halt if a job has failed,
        unlike EMR. So the cluster will continue to run even on job failure. You have to
        add a terminate-cluster operator on step failure through ditto itself.
    """
    create_op_task_id = TransformerUtils.get_task_id_from_xcom_pull(
        src_operator.job_flow_id)
    create_op: BaseOperator = \
        TransformerUtils.find_op_in_fragment_list(
            upstream_fragments,
            operator_type=ConnectedAzureHDInsightCreateClusterOperator,
            task_id=create_op_task_id)

    if not create_op:
        raise UpstreamOperatorNotFoundException(
            ConnectedAzureHDInsightCreateClusterOperator, EmrJobFlowSensor)

    monitor_cluster_op = AzureHDInsightClusterSensor(
        create_op.cluster_name,
        azure_conn_id=create_op.azure_conn_id,
        poke_interval=5,
        task_id=f"{create_op.task_id}_monitor_cluster",
        dag=self.dag)

    self.copy_op_attrs(monitor_cluster_op, src_operator)
    self.sign_op(monitor_cluster_op)

    return DAGFragment([monitor_cluster_op])
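# Illustrative sketch (assumption: the task id below is a placeholder, not taken from the
# repo). An EmrJobFlowSensor's job_flow_id is typically a templated xcom_pull of the
# create-job-flow task, which is the shape get_task_id_from_xcom_pull() reverses above in
# order to locate the transformed create-cluster operator among the upstream fragments.
EXAMPLE_JOB_FLOW_ID_TEMPLATE = \
    "{{ task_instance.xcom_pull('create_job_flow', key='return_value') }}"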
def test_find_op_in_dag_fragment(self):
    with DAG(dag_id='d1', default_args=DEFAULT_DAG_ARGS) as dag:
        op1 = DummyOperator(task_id='d1t1')
        op2 = DummyOperator(task_id='d1t2')
        op3 = PythonOperator(task_id='d1t3', python_callable=print)
        op4 = DummyOperator(task_id='d1t4')
        op1 >> [op2, op3] >> op4

        self.assertEqual(
            op3,
            TransformerUtils.find_op_in_dag_fragment(
                DAGFragment([op1]), PythonOperator))
def assert_dags_equals(test_case: unittest.TestCase, expected_dag: DAG, actual_dag: DAG):
    expected_dg = TransformerUtils.get_digraph_from_airflow_dag(expected_dag)
    actual_dg = TransformerUtils.get_digraph_from_airflow_dag(actual_dag)

    test_case.assertEqual(len(expected_dg.nodes), len(actual_dg.nodes))

    def node_matcher(mismatches, n1, n2):
        i: BaseOperator = n1['op']
        j: BaseOperator = n2['op']
        diff = deepdiff.DeepDiff(
            i, j,
            exclude_paths=['root._dag', 'root.azure_hook.client'],
            exclude_regex_paths=[
                re.compile(".*hook.*conn.*"),
                re.compile(".*{}.*".format(Transformer.TRANSFORMED_BY_HEADER))
            ],
            view='tree',
            verbose_level=10)
        if diff != {}:
            mismatches.append((i, j, diff))
        return diff == {}

    mismatches = []
    is_isomorphic = nx.is_isomorphic(
        expected_dg, actual_dg,
        node_match=functools.partial(node_matcher, mismatches))

    if not is_isomorphic:
        for i, j, diff in mismatches:
            print(f"Operators {i} and {j} are not equal")
            pprint(diff, indent=2)

    test_case.assertTrue(is_isomorphic)
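# Illustrative usage sketch (assumption: a hypothetical test case, not part of the repo's
# own suite; it reuses DEFAULT_DAG_ARGS from the tests above). Two structurally identical
# DAGs compare equal; the node matcher above excludes the operators' `_dag` back-references,
# so differing dag_ids do not matter.
class ExampleDagEqualityTest(unittest.TestCase):  # hypothetical
    def test_identical_dags_compare_equal(self):
        with DAG(dag_id='expected', default_args=DEFAULT_DAG_ARGS) as expected_dag:
            DummyOperator(task_id='t1') >> DummyOperator(task_id='t2')
        with DAG(dag_id='actual', default_args=DEFAULT_DAG_ARGS) as actual_dag:
            DummyOperator(task_id='t1') >> DummyOperator(task_id='t2')
        assert_dags_equals(self, expected_dag, actual_dag)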
def transform(self, src_operator: BaseOperator, parent_fragment: DAGFragment,
              upstream_fragments: List[DAGFragment]) -> DAGFragment:
    """
    This transformer assumes and relies on the fact that an upstream transformation of an
    :class:`~airflow.contrib.operators.emr_create_job_flow_operator.EmrCreateJobFlowOperator`
    has already taken place, since it needs to find the output of that transformation to get
    the `cluster_name` and `azure_conn_id` from that operator (which should have been an
    :class:`~airflowhdi.operators.AzureHDInsightCreateClusterOperator`)

    Creates an :class:`~airflowhdi.operators.AzureHDInsightDeleteClusterOperator`
    to terminate the cluster.
    """
    create_op_task_id = TransformerUtils.get_task_id_from_xcom_pull(
        src_operator.job_flow_id)
    create_op: BaseOperator = \
        TransformerUtils.find_op_in_fragment_list(
            upstream_fragments,
            operator_type=ConnectedAzureHDInsightCreateClusterOperator,
            task_id=create_op_task_id)

    if not create_op:
        raise UpstreamOperatorNotFoundException(
            ConnectedAzureHDInsightCreateClusterOperator, EmrTerminateJobFlowOperator)

    emr_terminate_op: EmrTerminateJobFlowOperator = src_operator
    terminate_cluster_op = AzureHDInsightDeleteClusterOperator(
        task_id=emr_terminate_op.task_id,
        azure_conn_id=create_op.azure_conn_id,
        cluster_name=create_op.cluster_name,
        dag=self.dag)

    self.copy_op_attrs(terminate_cluster_op, src_operator)
    self.sign_op(terminate_cluster_op)
    terminate_cluster_op.trigger_rule = TriggerRule.ALL_DONE

    return DAGFragment([terminate_cluster_op])
def test_get_digraph_from_airflow_dag(self):
    with DAG(dag_id='d1', default_args=DEFAULT_DAG_ARGS) as dag:
        op1 = DummyOperator(task_id='op1')
        op2 = DummyOperator(task_id='op2')
        op3 = DummyOperator(task_id='op3')
        op4 = DummyOperator(task_id='op4')
        op1 >> [op2, op3] >> op4

        expected_dg = nx.DiGraph()
        expected_dg.add_edge(op1, op2)
        expected_dg.add_edge(op1, op3)
        expected_dg.add_edge(op2, op4)
        expected_dg.add_edge(op3, op4)

        actual_dg = TransformerUtils.get_digraph_from_airflow_dag(dag)

        diff_dg: nx.DiGraph = nx.symmetric_difference(expected_dg, actual_dg)
        self.assertEqual(len(diff_dg.edges), 0)
def draw_dag_graphiviz_rendering(dag: DAG, colorer=ut_colorer, relabeler=ut_relabeler,
                                 legender=None, figsize=[6.4, 4.8],
                                 legend_own_figure=False):
    dg = TransformerUtils.get_digraph_from_airflow_dag(dag)

    labels = {}
    if relabeler:
        labels = relabeler(dg)

    color_map = []
    if colorer:
        color_map = colorer(dg)

    dg.graph.setdefault('graph', {})['rankdir'] = 'LR'
    dg.graph.setdefault('graph', {})['newrank'] = 'true'

    plt.figure(figsize=figsize)
    plt.title(dag.dag_id)

    pos = graphviz_layout(dg, prog='dot', args='-Gnodesep=0.1')
    rads = random.uniform(0.05, 0.1)
    nx.draw_networkx(dg, pos=pos, labels=labels, font_size=8,
                     node_color=color_map, node_size=900,
                     font_color='white', font_weight='bold',
                     connectionstyle=f"arc3, rad={rads}")

    if legender:
        if legend_own_figure:
            plt.figure()
            plt.title(dag.dag_id)
            plt.rcParams["legend.fontsize"] = 8
            plt.legend(handles=legender(dg), ncol=2)
        else:
            plt.rcParams["legend.fontsize"] = 7
            plt.legend(handles=legender(dg), borderaxespad=0.9,
                       ncol=2, loc='lower center')
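# Illustrative usage sketch (assumption: a matplotlib backend and pygraphviz/pydot are
# available for graphviz_layout; the function name and output path are placeholders).
def example_render_dag_to_png(dag: DAG, out_path: str = '/tmp/example_dag.png'):
    draw_dag_graphiviz_rendering(dag, figsize=[10, 6])
    plt.savefig(out_path, bbox_inches='tight')  # persist the figure drawn above
    plt.close('all')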
def transform(self, src_operator: BaseOperator, parent_fragment: DAGFragment,
              upstream_fragments: List[DAGFragment]) -> DAGFragment:
    """
    This transformer assumes and relies on the fact that an upstream transformation of an
    :class:`~airflow.contrib.operators.emr_create_job_flow_operator.EmrCreateJobFlowOperator`
    has already taken place, since it needs to find the output of that transformation to get
    the `cluster_name` and `azure_conn_id` from that operator (which should have been an
    :class:`~airflowhdi.operators.AzureHDInsightCreateClusterOperator`)

    This transformer also requires that transformations of
    :class:`~airflow.contrib.operators.emr_add_steps_operator.EmrAddStepsOperator` to
    :class:`~airflowhdi.operators.LivyBatchOperator` or
    :class:`~airflowhdi.operators.AzureHDInsightSshOperator` already exist in the
    `upstream_fragments`, which can then be monitored by the output tasks of this
    transformer. It needs to search for those ops upstream to find their task IDs.

    Adds a :class:`~airflowhdi.sensors.LivyBatchSensor` if it was a Livy spark job.
    No sensor is required for a transformed
    :class:`~airflowhdi.operators.AzureHDInsightSshOperator`, as it is synchronous.
    """
    create_op_task_id = TransformerUtils.get_task_id_from_xcom_pull(
        src_operator.job_flow_id)
    create_op: BaseOperator = \
        TransformerUtils.find_op_in_fragment_list(
            upstream_fragments,
            operator_type=ConnectedAzureHDInsightCreateClusterOperator,
            task_id=create_op_task_id)

    if not create_op:
        raise UpstreamOperatorNotFoundException(
            ConnectedAzureHDInsightCreateClusterOperator, EmrStepSensor)

    emr_step_sensor_op: EmrStepSensor = src_operator
    emr_add_step_task_id = TransformerUtils.get_task_id_from_xcom_pull(
        emr_step_sensor_op.step_id)
    emr_add_step_step_id = TransformerUtils.get_list_index_from_xcom_pull(
        emr_step_sensor_op.step_id)
    target_step_task_id = EmrAddStepsOperatorTransformer.get_target_step_task_id(
        emr_add_step_task_id, emr_add_step_step_id)

    add_step_op: BaseOperator = \
        TransformerUtils.find_op_in_fragment_list_strict(
            upstream_fragments, task_id=target_step_task_id)

    if isinstance(add_step_op, LivyBatchOperator):
        step_sensor_op = LivyBatchSensor(
            batch_id=f"{{{{ task_instance.xcom_pull('{target_step_task_id}', key='return_value') }}}}",
            task_id=emr_step_sensor_op.task_id,
            azure_conn_id=create_op.azure_conn_id,
            cluster_name=create_op.cluster_name,
            verify_in="yarn",
            dag=self.dag)
    else:
        # don't need a sensor for the ssh operator
        step_sensor_op = DummyOperator(task_id=emr_step_sensor_op.task_id,
                                       dag=self.dag)

    self.copy_op_attrs(step_sensor_op, emr_step_sensor_op)
    self.sign_op(step_sensor_op)

    return DAGFragment([step_sensor_op])
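# Illustrative sketch (assumption: the task id below is a placeholder, not taken from the
# repo). An EmrStepSensor's step_id is typically a templated xcom_pull with a trailing list
# index, which is the shape the two TransformerUtils calls above pick apart: the task id of
# the EmrAddStepsOperator and the index of the step within its returned list of step ids.
EXAMPLE_EMR_STEP_ID_TEMPLATE = \
    "{{ task_instance.xcom_pull('add_steps', key='return_value')[0] }}"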
def transform(self, src_operator: BaseOperator, parent_fragment: DAGFragment,
              upstream_fragments: List[DAGFragment]) -> DAGFragment:
    """
    This transformer assumes and relies on the fact that an upstream transformation of an
    :class:`~airflow.contrib.operators.emr_create_job_flow_operator.EmrCreateJobFlowOperator`
    has already taken place, since it needs to find the output of that transformation to get
    the `cluster_name` and `azure_conn_id` from that operator (which should have been an
    :class:`~airflowhdi.operators.AzureHDInsightCreateClusterOperator`)

    It then goes through the EMR steps of this
    :class:`~airflow.contrib.operators.emr_add_steps_operator.EmrAddStepsOperator`
    and creates a :class:`~airflowhdi.operators.LivyBatchOperator` or an
    :class:`~airflowhdi.operators.AzureHDInsightSshOperator` for each corresponding step,
    based on grokking the step's params and figuring out whether it is a spark job being
    run or an arbitrary hadoop command like `distcp`, `hdfs` or the like.

    .. note::

        This transformer creates multiple operators from a single source operator

    .. note::

        The spark configuration for the livy spark job is derived from
        `step['HadoopJarStep']['Properties']` of the EMR step, or could even be
        specified at the cluster level itself when transforming the job flow
    """
    create_op_task_id = TransformerUtils.get_task_id_from_xcom_pull(
        src_operator.job_flow_id)
    create_op: BaseOperator = \
        TransformerUtils.find_op_in_fragment_list(
            upstream_fragments,
            operator_type=ConnectedAzureHDInsightCreateClusterOperator,
            task_id=create_op_task_id)

    if not create_op:
        raise UpstreamOperatorNotFoundException(
            ConnectedAzureHDInsightCreateClusterOperator, EmrAddStepsOperator)

    emr_add_steps_op: EmrAddStepsOperator = src_operator
    dag_fragment_steps = []

    steps_added_op = DummyOperator(
        task_id=f"{emr_add_steps_op.task_id}_added", dag=self.dag)
    self.sign_op(steps_added_op)

    for step in emr_add_steps_op.steps:
        name = step['Name']
        ssh_command = None
        livy_file = None
        livy_arguments = None
        livy_main_class = None

        if 'command-runner' in step['HadoopJarStep']['Jar']:
            command_runner_cmd = step['HadoopJarStep']['Args']
            if '/usr/bin/spark-submit' in command_runner_cmd[0]:
                livy_file = command_runner_cmd[1]
                livy_arguments = command_runner_cmd[2:]
            elif 's3-dist-cp' in command_runner_cmd[0]:
                src = None
                dest = None
                for arg in command_runner_cmd[1:]:
                    if arg.startswith('--src='):
                        src = arg.split("--src=", 1)[1]
                    if arg.startswith('--dest='):
                        dest = arg.split("--dest=", 1)[1]
                mappers = EmrAddStepsOperatorTransformer.HADOOP_DISTCP_DEFAULT_MAPPERS
                ssh_command = f"hadoop distcp -m {mappers} {src} {dest}"
            elif 'hdfs' in command_runner_cmd[0]:
                ssh_command = " ".join(command_runner_cmd)
            else:
                raise Exception("This kind of step is not supported right now",
                                command_runner_cmd[0])
        else:
            livy_file = step['HadoopJarStep']['Jar']
            livy_arguments = step['HadoopJarStep']['Args']
            livy_main_class = step['HadoopJarStep'].get('MainClass', None)
            if 'Properties' in step['HadoopJarStep']:
                properties = ""
                for key, val in step['HadoopJarStep']['Properties']:
                    properties += f"-D{key}={val} "
                self.spark_conf['spark.executor.extraJavaOptions'] = properties
                self.spark_conf['spark.driver.extraJavaOptions'] = properties

        target_step_task_id = EmrAddStepsOperatorTransformer.get_target_step_task_id(
            emr_add_steps_op.task_id, emr_add_steps_op.steps.index(step))

        if ssh_command is not None:
            step_op = AzureHDInsightSshOperator(
                cluster_name=create_op.cluster_name,
                azure_conn_id=create_op.azure_conn_id,
                command=ssh_command,
                task_id=target_step_task_id,
                dag=self.dag)
        else:
            step_op = LivyBatchOperator(
                name=name,
                file=livy_file,
                arguments=livy_arguments,
                class_name=livy_main_class,
                azure_conn_id=create_op.azure_conn_id,
                cluster_name=create_op.cluster_name,
                proxy_user=self.proxy_user,
                conf=self.spark_conf,
                task_id=target_step_task_id,
                dag=self.dag)

        self.copy_op_attrs(step_op, emr_add_steps_op)
        self.sign_op(step_op)
        step_op.trigger_rule = TriggerRule.ALL_SUCCESS
        step_op.set_downstream(steps_added_op)
        dag_fragment_steps.append(step_op)

    return DAGFragment(dag_fragment_steps)
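# Illustrative input sketch (assumption: plain example data, not taken from the repo).
# Two EMR step dicts of the shape the loop above understands: a command-runner
# spark-submit step (transformed into a LivyBatchOperator) and an s3-dist-cp step
# (transformed into an AzureHDInsightSshOperator running `hadoop distcp`).
EXAMPLE_EMR_STEPS = [
    {
        'Name': 'example_spark_step',
        'HadoopJarStep': {
            'Jar': 'command-runner.jar',
            'Args': ['/usr/bin/spark-submit', 's3://bucket/app.py', '--date', '{{ ds }}'],
        },
    },
    {
        'Name': 'example_copy_step',
        'HadoopJarStep': {
            'Jar': 'command-runner.jar',
            'Args': ['s3-dist-cp', '--src=s3://bucket/in', '--dest=hdfs:///data/in'],
        },
    },
]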
def transform_operators(self, src_dag: DAG):
    """
    Transforms the operators in the source DAG and creates the target DAG out of the
    returned :class:`~ditto.api.DAGFragment`\'s. Finds the transformers by running each
    operator through all the resolvers passed.

    Does a breadth-first traversal of the source DAG such that the results of the
    transformation of upstream ops (and previous ops in the same level) are available
    to downstream transformers in level-order. This is helpful for real-world
    transformation use cases, like having a spark step op transformer read the result
    of the transformation of a cluster create op transformer.

    Caches the results of transformations to avoid repeat work, as this is a graph,
    not a tree, being traversed.

    .. note::

        Stitches the final target DAG after having transformed all operators.

    :param src_dag: the source airflow DAG to be operator-transformed
    :return: does not return anything. mutates ``self.target_dag`` directly
    """
    src_task_q: "Queue[(BaseOperator,DAGFragment)]" = Queue()
    for root in src_dag.roots:
        src_task_q.put((root, None))

    # a list representing all processed fragments so far
    # transformers can use this to fetch information from the
    # level before or even tasks before this one in the same level
    # we'll also use this to stitch our final airflow DAG together
    transformed_dag_fragments = []

    while not src_task_q.empty():
        src_task, parent_fragment = src_task_q.get()

        # since this is a graph, we can encounter the same source
        # task repeatedly if it has multiple parents in the src_dag
        # to avoid transforming it repeatedly, check if it has already been seen
        task_dag_fragment = None
        cached_fragment = False
        if src_task in self.transformer_cache:
            log.info("Already transformed source task: %s", src_task)
            task_dag_fragment = self.transformer_cache[src_task]
            cached_fragment = True
        else:
            # get transformer class for this operator
            transformer_cl = None
            for resolver in self.transformer_resolvers:
                transformer_cl = resolver.resolve_transformer(src_task)
                if transformer_cl:
                    log.info(
                        f"Found transformer for operator {src_task.__class__.__name__}"
                        f": {transformer_cl.__name__} using {resolver.__class__.__name__}")
                    break

            if not transformer_cl:
                transformer_cl = CopyTransformer

            # get transformer defaults for this operator
            transformer_defaults = None
            if self.transformer_defaults is not None:
                if transformer_cl in self.transformer_defaults.defaults:
                    transformer_defaults = self.transformer_defaults.defaults[transformer_cl]

            # create the transformer
            transformer = transformer_cl(self.target_dag, transformer_defaults)

            # do transformation, and get DAGFragment
            task_dag_fragment = transformer.transform(
                src_task, parent_fragment, transformed_dag_fragments)
            self.transformer_cache[src_task] = task_dag_fragment

            # add this transformed output fragment to
            # the upstream fragments processed so far
            transformed_dag_fragments.append(task_dag_fragment)

            # add children to queue
            if src_task.downstream_list:
                for downstream_task in src_task.downstream_list:
                    src_task_q.put((downstream_task, task_dag_fragment))

        # chain it to the parent
        if parent_fragment:
            task_dag_fragment.add_parent(parent_fragment)

    # convert dag fragment relationships to airflow dag relationships
    # for the processed fragments (which are now available in topologically
    # sorted order)
    for output_fragment in transformed_dag_fragments:
        # get a flattened list of roots for the child DAGFragments
        all_child_fragment_roots = [
            step for frag in output_fragment.children for step in frag.tasks
        ]

        # attach the flattened roots of child DAGFragments to this DAGFragment
        TransformerUtils.add_downstream_dag_fragment(
            output_fragment, DAGFragment(all_child_fragment_roots))
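# Illustrative sketch (assumption: not part of ditto; it subclasses CopyTransformer, the
# fallback transformer referenced above, and reuses this module's imports; the task id is
# a placeholder). It shows what the breadth-first level-ordering buys an operator
# transformer: the outputs of transformations that already ran are searchable in
# `upstream_fragments`, exactly as the EMR transformers do.
class ExampleUpstreamLookupTransformer(CopyTransformer):  # hypothetical
    def transform(self, src_operator, parent_fragment, upstream_fragments):
        # look up a task produced by an earlier (upstream or same-level) transformation
        earlier_op = TransformerUtils.find_op_in_fragment_list_strict(
            upstream_fragments, task_id='create_cluster')
        log.info("operator produced by an upstream transformation: %s", earlier_op)
        return super().transform(src_operator, parent_fragment, upstream_fragments)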
def transform_sub_dags(self, src_dag: DAG):
    """
    Transforms the subdags of the source DAG, as matched by the
    :class:`~ditto.api.TaskMatcher` DAG provided by the
    :class:`~ditto.api.SubDagTransformer`\'s
    :meth:`~ditto.api.SubDagTransformer.get_sub_dag_matcher` method.

    Multiple subdag transformers can run through the source DAG, and each of them can
    match+transform multiple subdags of the source DAG, *and* each such transformation
    can return multiple subdags as a result, so this can get quite flexible if you want.
    The final DAG is carefully stitched with all the results of the subdag
    transformations. See the unit tests at `test_core.py` for complex examples.

    .. note::

        If your matched input subdag had different leaves pointing to different
        operators/nodes, the transformed subdag's leaves will just get multiplexed to
        all the leaves of the `source DAG`, since it is not possible to know which new
        leaf is to be stitched to which node of the source DAG, and resolve new
        relationships based on old ones.

    .. warning::

        Make sure that you don't provide :class:`~ditto.api.SubDagTransformer`\'s with
        overlapping subdag matchers, otherwise things can understandably get messy.

    .. seealso::

        The core logic behind this method lies in a graph algorithm called subgraph
        isomorphism, and is explained in detail at
        :meth:`~ditto.utils.TransformerUtils.find_sub_dag`

    :param src_dag: the source airflow DAG to be subdag-transformed
    :return: does not return anything. mutates the passed ``src_dag`` directly,
        which is why you should pass a copy of the source DAG.
    """
    for subdag_transformer_cl in self.subdag_transformers:
        transformer_defaults = None
        if self.transformer_defaults is not None:
            if subdag_transformer_cl in self.transformer_defaults.defaults:
                transformer_defaults = self.transformer_defaults.defaults[subdag_transformer_cl]

        subdag_transformer = subdag_transformer_cl(src_dag, transformer_defaults)
        matcher_roots = subdag_transformer.get_sub_dag_matcher()

        # find matching sub-dags using the [TaskMatcher] DAG
        src_dag_dg, subdags = TransformerUtils.find_sub_dag(src_dag, matcher_roots)
        src_dag_nodes = [t for t in src_dag_dg.nodes]

        # transform each matching sub-dag and replace it in the DAG
        cloned_subdags = copy.deepcopy(subdags)
        # deep copy since DiGraph holds weak refs and that creates a problem
        # with traversing the DiGraph after deleting nodes from the original airflow DAG

        for subdag, cloned_subdag in zip(subdags, cloned_subdags):
            # upstream tasks are nodes in the main dag in-edges of the nodes in this sub-dag
            # which do not belong to this sub-dag
            subdag_upstream_tasks = set(n for edge in src_dag_dg.in_edges(nbunch=subdag.nodes)
                                        for n in edge if n not in subdag)

            # downstream tasks are nodes in the main dag out-edges of the nodes in this sub-dag
            # which do not belong to this sub-dag
            subdag_downstream_tasks = set(n for edge in src_dag_dg.edges(nbunch=subdag.nodes)
                                          for n in edge if n not in subdag)

            subdag_nodes = [n for n in subdag.nodes]
            for task in subdag_nodes:
                TransformerUtils.remove_task_from_dag(src_dag, src_dag_nodes, task)

            new_subdag_fragment = subdag_transformer.transform(
                cloned_subdag, DAGFragment(subdag_upstream_tasks))

            # attach new subdag to upstream
            if subdag_upstream_tasks:
                for parent in subdag_upstream_tasks:
                    for new_root in new_subdag_fragment.tasks:
                        parent.set_downstream(new_root)

            # assign new subdag to src_dag
            TransformerUtils.add_dag_fragment_to_dag(src_dag, new_subdag_fragment)

            # attach downstream to the leaves of the new subdag
            TransformerUtils.add_downstream_dag_fragment(
                new_subdag_fragment, DAGFragment(subdag_downstream_tasks))
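# Illustrative sketch (assumption: not part of ditto; SubDagTransformer, ClassTaskMatcher,
# DummyOperator and PythonOperator are assumed to be importable here, and the task id is a
# placeholder). A two-node matcher DAG is returned from get_sub_dag_matcher(), and every
# matched subdag is collapsed into a single DummyOperator in the transformed fragment,
# which the loop above then stitches back between the matched subdag's upstream and
# downstream tasks.
class ExampleCollapseSubDagTransformer(SubDagTransformer):  # hypothetical
    def get_sub_dag_matcher(self):
        first_matcher = ClassTaskMatcher(DummyOperator)
        second_matcher = ClassTaskMatcher(PythonOperator)
        first_matcher >> second_matcher
        return [first_matcher]

    def transform(self, subdag: nx.DiGraph, parent_fragment: DAGFragment) -> DAGFragment:
        collapsed = DummyOperator(task_id='collapsed_subdag', dag=self.dag)
        return DAGFragment([collapsed])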
def test_get_step_id_from_xcom_pull(self, test, result):
    self.assertEqual(result, TransformerUtils.get_list_index_from_xcom_pull(test))
def test_get_task_id_from_xcom_pull(self, test, result):
    self.assertEqual(result, TransformerUtils.get_task_id_from_xcom_pull(test))