def test_xcom_enable_pickle_type(self):
    """Round-trip a dict through XCom with pickling enabled.

    The stored value is checked twice: through ``XCom.get_one`` and by
    querying the ``XCom`` table directly with a session.
    """
    json_obj = {"key": "value"}
    execution_date = timezone.utcnow()
    key = "xcom_test2"
    dag_id = "test_dag2"
    task_id = "test_task2"

    configuration.set("core", "enable_xcom_pickling", "True")

    XCom.set(key=key,
             value=json_obj,
             dag_id=dag_id,
             task_id=task_id,
             execution_date=execution_date)

    ret_value = XCom.get_one(key=key,
                             dag_id=dag_id,
                             task_id=task_id,
                             execution_date=execution_date)

    self.assertEqual(ret_value, json_obj)

    session = settings.Session()
    try:
        row = session.query(XCom).filter(
            XCom.key == key,
            XCom.dag_id == dag_id,
            XCom.task_id == task_id,
            XCom.execution_date == execution_date,
        ).first()
        # Fail with a clear message instead of AttributeError on None.value.
        self.assertIsNotNone(row)
        ret_value = row.value
    finally:
        # Release the session so it does not outlive the test.
        session.close()

    self.assertEqual(ret_value, json_obj)
def test_xcom_get_many(self):
    """Store one XCom under two DAG/task pairs and fetch both with ``get_many``."""
    json_obj = {"key": "value"}
    execution_date = timezone.utcnow()
    key = "xcom_test4"
    dag_id1 = "test_dag4"
    task_id1 = "test_task4"
    dag_id2 = "test_dag5"
    task_id2 = "test_task5"

    # The pickling switch is ``[core] enable_xcom_pickling`` (the key used by
    # test_xcom_enable_pickle_type); the old "xcom_enable_pickling" spelling
    # did not match it, so pickling was never actually enabled here.
    configuration.set("core", "enable_xcom_pickling", "True")

    XCom.set(key=key,
             value=json_obj,
             dag_id=dag_id1,
             task_id=task_id1,
             execution_date=execution_date)

    XCom.set(key=key,
             value=json_obj,
             dag_id=dag_id2,
             task_id=task_id2,
             execution_date=execution_date)

    results = XCom.get_many(key=key,
                            execution_date=execution_date)

    # Without this check the loop below passes vacuously on an empty result.
    found = {(result.dag_id, result.task_id) for result in results}
    self.assertIn((dag_id1, task_id1), found)
    self.assertIn((dag_id2, task_id2), found)

    for result in results:
        self.assertEqual(result.value, json_obj)
def get_link(self, operator: BaseOperator, dttm: datetime) -> str:
    """
    Build the Dataflow job URL from the configuration the task pushed to XCom.

    :param operator: operator whose task pushed the ``DataflowJobLink.key`` XCom
    :param dttm: execution date used to look up the XCom
    :return: formatted ``DATAFLOW_JOB_LINK`` URL, or an empty string when no
        XCom value is stored
    """
    job_conf = XCom.get_one(
        key=DataflowJobLink.key,
        dag_id=operator.dag.dag_id,
        task_id=operator.task_id,
        execution_date=dttm,
    )
    if not job_conf:
        return ""
    return DATAFLOW_JOB_LINK.format(
        project_id=job_conf["project_id"],
        region=job_conf["region"],
        job_id=job_conf["job_id"],
    )
def get_link(self, operator: BaseOperator, dttm: datetime):
    """
    Build the Dataproc list-page URL from the configuration stored in XCom.

    The stored dict carries its own ``url`` template, which is filled in with
    the stored ``project_id``.

    :param operator: operator whose task pushed the ``DataprocListLink.key`` XCom
    :param dttm: execution date used to look up the XCom
    :return: formatted URL, or an empty string when no XCom value is stored
    """
    stored = XCom.get_one(
        key=DataprocListLink.key,
        dag_id=operator.dag.dag_id,
        task_id=operator.task_id,
        execution_date=dttm,
    )
    if not stored:
        return ""
    url_template = stored["url"]
    return url_template.format(project_id=stored["project_id"])
def test_write_xcom(self):
    """Exercise the experimental endpoint that writes an XCom value.

    Checks a successful write (verified in the database), a 404 for an
    unknown DAG and a 400 for an unparseable execution date.
    """
    url_template = ('/api/experimental/dags/{}/tasks/{}'
                    '/instances/{}/xcom/{}/{}')
    dag_id = 'example_bash_operator'
    task_id = 'xcom_task'
    key = 'xcom_key'
    value = 'xcom_value'
    execution_date = datetime.now().replace(microsecond=0)
    datetime_string = execution_date.isoformat()

    def post_xcom(target_dag, date_part):
        # Task, key and value are fixed; only the DAG and date segments vary.
        path = url_template.format(target_dag, task_id, date_part, key, value)
        return self.app.post(quote(path))

    # Successful write
    resp = post_xcom(dag_id, datetime_string)
    self.assertEqual(200, resp.status_code)

    # The written value must be readable back from the XCom table
    stored = XCom.get_one(execution_date=execution_date,
                          key=key,
                          task_id=task_id,
                          dag_id=dag_id)
    self.assertEqual(value, stored,
                     'XCom value: expected {}, found {}'.format(value, stored))

    # Unknown DAG
    self.assertEqual(404, post_xcom('does_not_exist_dag', datetime_string).status_code)

    # Malformed execution date
    self.assertEqual(400, post_xcom(dag_id, 'not_a_datetime').status_code)
def get_link(
    self,
    operator,
    dttm: Optional[datetime] = None,
    ti_key: Optional["TaskInstanceKey"] = None,
) -> str:
    """
    Get link to EMR cluster.

    The cluster (job flow) id is read from the task's ``return_value`` XCom,
    looked up by ``ti_key`` when given, otherwise by ``dttm``.

    :param operator: operator
    :param dttm: execution date; required when ``ti_key`` is not given
    :param ti_key: task instance key; takes precedence over ``dttm``
    :return: url link, or an empty string when no cluster id is stored
    :raises ValueError: if neither ``ti_key`` nor ``dttm`` is provided
    """
    if ti_key:
        flow_id = XCom.get_value(key="return_value", ti_key=ti_key)
    else:
        # An explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, which would silently query with execution_date=None.
        if not dttm:
            raise ValueError("Either ti_key or dttm must be provided")
        flow_id = XCom.get_one(key="return_value",
                               dag_id=operator.dag_id,
                               task_id=operator.task_id,
                               execution_date=dttm)
    return (
        f'https://console.aws.amazon.com/elasticmapreduce/home#cluster-details:{flow_id}'
        if flow_id
        else ''
    )
def get_link(self, operator: BaseOperator, dttm: datetime) -> str:
    """
    Build the AWS console URL for the EMR cluster created by the task.

    :param operator: operator whose ``return_value`` XCom holds the job flow id
    :param dttm: execution date used to look up the XCom
    :return: cluster-details URL, or an empty string when no id is stored
    """
    flow_id = XCom.get_one(
        key="return_value",
        dag_id=operator.dag.dag_id,
        task_id=operator.task_id,
        execution_date=dttm,
    )
    if not flow_id:
        return ''
    return f'https://console.aws.amazon.com/elasticmapreduce/home#cluster-details:{flow_id}'
def test_pull_parameters_by_task(env, create_dag, create_df_task, push_xcom):
    """Push an XCom from task 'three' and check it can be pulled by task id."""
    dag = create_dag
    # NOTE(review): ``create_df_task`` is requested but not used yet. The
    # parameter-merging step for this case looks unfinished (TODO).
    push_task = PythonOperator(task_id='three',
                               python_callable=push_xcom,
                               op_args=['three_value'],
                               dag=dag,
                               provide_context=True)
    push_ti = TaskInstance(task=push_task, execution_date=datetime.now())
    push_task.execute(push_ti.get_template_context())

    xcom = XCom.get_many(execution_date=push_ti.execution_date,
                         dag_ids=[dag.dag_id],
                         task_ids=[push_task.task_id])

    # Check the length before indexing so an empty result fails on the
    # assertion rather than with IndexError. The debug ``pprint`` and the
    # trailing ``assert False`` (which made the test always fail) are removed.
    assert len(xcom) == 1
    assert xcom[0].value == 'three_value'
def test_pull_parameters_by_specific_key(env, create_dag, create_df_task, push_xcom_key):
    """Push an XCom under an explicit key and check ``merge_parameters`` picks it up."""
    dag = create_dag
    df_task = create_df_task(dag)

    push_task = PythonOperator(
        task_id='two',
        python_callable=push_xcom_key,
        op_args=['two', 'two_value'],
        dag=dag,
        provide_context=True,
    )
    push_ti = TaskInstance(task=push_task, execution_date=datetime.now())
    push_task.execute(push_ti.get_template_context())

    pulled = XCom.get_many(execution_date=push_ti.execution_date,
                           dag_ids=[dag.dag_id])
    assert len(pulled) == 1
    assert pulled[0].value == 'two_value'

    merged_params = df_task.merge_parameters(push_ti)
    assert merged_params['two-specific-key'] == 'two_value'
def test_serialize(self, session):
    """Dumping an XCom row with the collection-item schema yields its metadata fields."""
    session.add(
        XCom(
            key='test_key',
            timestamp=self.default_time_parsed,
            execution_date=self.default_time_parsed,
            task_id='test_task_id',
            dag_id='test_dag',
        )
    )
    session.commit()

    stored_row = session.query(XCom).first()
    dumped = xcom_collection_item_schema.dump(stored_row)

    expected = {
        'key': 'test_key',
        'timestamp': self.default_time,
        'execution_date': self.default_time,
        'task_id': 'test_task_id',
        'dag_id': 'test_dag',
    }
    assert dumped == expected
def _get_task_instance_xcom_dict(dag_id, task_id, execution_date):
    """Return one task instance's XComs as ``{key: shortened str(value)}``.

    At most ``MAX_XCOM_LENGTH`` entries are kept, and each value is passed
    through ``shorten_xcom_value``. This is best effort: any failure is logged
    at INFO level and an empty dict is returned.
    """
    try:
        results = XCom.get_many(dag_ids=dag_id,
                                task_ids=task_id,
                                execution_date=execution_date)
        if not results:
            return {}

        xcom_dict = {xcom.key: str(xcom.value) for xcom in results}
        # islice already leaves shorter inputs untouched, so no length branch
        # is needed. Building a new dict avoids mutating one while iterating it.
        return {
            key: shorten_xcom_value(value)
            for key, value in itertools.islice(six.iteritems(xcom_dict),
                                               MAX_XCOM_LENGTH)
        }
    except Exception as e:
        # Deliberately broad: callers expect an empty dict, never an exception.
        logging.info("Failed to get xcom dict. Exception: %s", e)
        return {}
def get_link(self, operator: BaseOperator, dttm: datetime) -> str:
    """
    Build the URL of the Qubole command result page for the task's command.

    The host comes from the operator's Qubole connection, with a trailing
    ``api`` replaced by the analyze path. The public Qubole host is used when
    the connection has no host.

    :param operator: operator carrying ``qubole_conn_id`` (attribute or kwargs)
    :param dttm: execution date used to look up the ``qbol_cmd_id`` XCom
    :return: url link, or an empty string when no command id is stored
    """
    conn_id = getattr(operator, "qubole_conn_id", None) or operator.kwargs['qubole_conn_id']  # type: ignore[attr-defined]
    conn = BaseHook.get_connection(conn_id)

    host = (
        re.sub(r'api$', 'v2/analyze?command_id=', conn.host)
        if conn and conn.host
        else 'https://api.qubole.com/v2/analyze?command_id='
    )

    qds_command_id = XCom.get_one(
        key='qbol_cmd_id',
        dag_id=operator.dag.dag_id,
        task_id=operator.task_id,
        execution_date=dttm,
    )
    if not qds_command_id:
        return ''
    return host + str(qds_command_id)
def test_write_xcom(self):
    """Write an XCom through the experimental API and probe its error responses.

    A valid request must return 200 and persist the value. An unknown DAG must
    return 404 and a malformed date must return 400.
    """
    url_template = ('/api/experimental/dags/{}/tasks/{}'
                    '/instances/{}/xcom/{}/{}')
    dag_id = 'example_bash_operator'
    task_id = 'xcom_task'
    key = 'xcom_key'
    value = 'xcom_value'
    execution_date = datetime.now().replace(microsecond=0)
    datetime_string = execution_date.isoformat()

    def build_url(dag, date_segment):
        return quote(url_template.format(dag, task_id, date_segment, key, value))

    ok_response = self.app.post(build_url(dag_id, datetime_string))
    self.assertEqual(200, ok_response.status_code)

    xcom_value = XCom.get_one(execution_date=execution_date,
                              key=key,
                              task_id=task_id,
                              dag_id=dag_id)
    self.assertEqual(
        value, xcom_value,
        'XCom value: expected {}, found {}'.format(value, xcom_value))

    error_cases = (
        ('does_not_exist_dag', datetime_string, 404),  # nonexistent dag
        (dag_id, 'not_a_datetime', 400),               # bad datetime format
    )
    for bad_dag, bad_date, expected_status in error_cases:
        response = self.app.post(build_url(bad_dag, bad_date))
        self.assertEqual(expected_status, response.status_code)
def test_serialize(self, session):
    """Dumping an XCom row with ``xcom_schema`` includes its decoded value."""
    row = XCom(
        key='test_key',
        timestamp=self.default_time_parsed,
        execution_date=self.default_time_parsed,
        task_id='test_task_id',
        dag_id='test_dag',
        value=b'test_binary',
    )
    session.add(row)
    session.commit()

    fetched = session.query(XCom).first()
    dumped = xcom_schema.dump(fetched)

    expected = {
        'key': 'test_key',
        'timestamp': self.default_time,
        'execution_date': self.default_time,
        'task_id': 'test_task_id',
        'dag_id': 'test_dag',
        'value': 'test_binary',
    }
    self.assertEqual(dumped, expected)
def get_link(self, operator, dttm):
    """
    Build the Azure Data Factory monitoring URL for the task's pipeline run.

    The run id is read from the ``run_id`` XCom. The subscription id comes
    from the connection extras. Resource group and factory name come from the
    operator when set, otherwise from the connection extras.

    :param operator: operator providing the connection id and optional
        resource group / factory name
    :param dttm: execution date used to look up the ``run_id`` XCom
    :return: url link, or an empty string when no run id is stored
    """
    run_id = XCom.get_one(
        key="run_id",
        dag_id=operator.dag.dag_id,
        task_id=operator.task_id,
        execution_date=dttm,
    )
    # Without a run id the URL would embed the literal text "None". Return an
    # empty link instead, as operator links do when their XCom is missing.
    if not run_id:
        return ""

    conn = BaseHook.get_connection(operator.azure_data_factory_conn_id)
    subscription_id = conn.extra_dejson["extra__azure_data_factory__subscriptionId"]
    # Both Resource Group Name and Factory Name can either be declared in the Azure Data Factory
    # connection or passed directly to the operator.
    resource_group_name = operator.resource_group_name or conn.extra_dejson.get(
        "extra__azure_data_factory__resource_group_name"
    )
    factory_name = operator.factory_name or conn.extra_dejson.get(
        "extra__azure_data_factory__factory_name"
    )
    url = (
        f"https://adf.azure.com/en-us/monitoring/pipelineruns/{run_id}"
        f"?factory=/subscriptions/{subscription_id}/"
        f"resourceGroups/{resource_group_name}/providers/Microsoft.DataFactory/"
        f"factories/{factory_name}"
    )
    return url