Example #1
    def test_console_extra_link_serialized_field(self):
        with self.dag:
            training_op = MLEngineStartTrainingJobOperator(**self.TRAINING_DEFAULT_ARGS)
        serialized_dag = SerializedDAG.to_dict(self.dag)
        dag = SerializedDAG.from_dict(serialized_dag)
        simple_task = dag.task_dict[self.TRAINING_DEFAULT_ARGS['task_id']]

        # Check Serialized version of operator link
        self.assertEqual(
            serialized_dag["dag"]["tasks"][0]["_operator_extra_links"],
            [{"airflow.providers.google.cloud.operators.mlengine.AIPlatformConsoleLink": {}}]
        )

        # Check DeSerialized version of operator link
        self.assertIsInstance(list(simple_task.operator_extra_links)[0], AIPlatformConsoleLink)

        job_id = self.TRAINING_DEFAULT_ARGS['job_id']
        project_id = self.TRAINING_DEFAULT_ARGS['project_id']
        gcp_metadata = {
            "job_id": job_id,
            "project_id": project_id,
        }

        ti = TaskInstance(
            task=training_op,
            execution_date=DEFAULT_DATE,
        )
        ti.xcom_push(key='gcp_metadata', value=gcp_metadata)

        self.assertEqual(
            f"https://console.cloud.google.com/ai-platform/jobs/{job_id}?project={project_id}",
            simple_task.get_extra_links(DEFAULT_DATE, AIPlatformConsoleLink.name),
        )

        self.assertEqual(
            '',
            simple_task.get_extra_links(datetime.datetime(2019, 1, 1), AIPlatformConsoleLink.name),
        )
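
The AIPlatformConsoleLink exercised here is an XCom-backed operator extra link: the operator pushes a "gcp_metadata" dict at runtime, and the link class turns it into a console URL (or an empty string when no XCom exists for that execution date, as the final assertion checks). A rough sketch of such a link class, based only on the behaviour asserted above and the legacy get_link(operator, dttm) signature, not the provider's exact code:

from airflow.models import TaskInstance
from airflow.models.baseoperator import BaseOperatorLink

AI_PLATFORM_JOB_LINK = "https://console.cloud.google.com/ai-platform/jobs/{job_id}?project={project_id}"


class ConsoleLinkSketch(BaseOperatorLink):
    """Build the console URL from the 'gcp_metadata' value the operator pushed to XCom."""

    name = "AI Platform Console"

    def get_link(self, operator, dttm):
        ti = TaskInstance(task=operator, execution_date=dttm)
        gcp_metadata = ti.xcom_pull(task_ids=operator.task_id, key="gcp_metadata")
        if not gcp_metadata:
            # No XCom for this execution date -> empty link, as asserted above.
            return ""
        return AI_PLATFORM_JOB_LINK.format(
            job_id=gcp_metadata["job_id"], project_id=gcp_metadata["project_id"]
        )
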
Example #2
 def test_set_machine_type_with_templates(self, _):
     dag_id = 'test_dag_id'
     args = {
         'start_date': DEFAULT_DATE
     }
     self.dag = DAG(dag_id, default_args=args)  # pylint: disable=attribute-defined-outside-init
     op = ComputeEngineSetMachineTypeOperator(
         project_id='{{ dag.dag_id }}',
         zone='{{ dag.dag_id }}',
         resource_id='{{ dag.dag_id }}',
         body={},
         gcp_conn_id='{{ dag.dag_id }}',
         api_version='{{ dag.dag_id }}',
         task_id='id',
         dag=self.dag
     )
     ti = TaskInstance(op, DEFAULT_DATE)
     ti.render_templates()
     self.assertEqual(dag_id, getattr(op, 'project_id'))
     self.assertEqual(dag_id, getattr(op, 'zone'))
     self.assertEqual(dag_id, getattr(op, 'resource_id'))
     self.assertEqual(dag_id, getattr(op, 'gcp_conn_id'))
     self.assertEqual(dag_id, getattr(op, 'api_version'))
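
The pattern above recurs in most templating tests in this collection: build the operator with Jinja expressions in its templated fields, wrap it in a TaskInstance, call render_templates(), then inspect the rendered attributes. A minimal standalone sketch of the same idea, assuming an Airflow version that still accepts the TaskInstance(task, execution_date) constructor and the airflow.operators.bash_operator import path used in that era:

import datetime

from airflow import DAG
from airflow.models import TaskInstance
from airflow.operators.bash_operator import BashOperator

DEFAULT_DATE = datetime.datetime(2020, 1, 1)

dag = DAG('render_demo', start_date=DEFAULT_DATE)
# bash_command is a templated field, so the Jinja expression is resolved
# when the TaskInstance renders its templates.
op = BashOperator(task_id='echo', bash_command='echo {{ dag.dag_id }}', dag=dag)

ti = TaskInstance(op, DEFAULT_DATE)
ti.render_templates()
assert op.bash_command == 'echo render_demo'
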
Example #3
    def setUp(self):
        super().setUp()
        self.wasb_log_folder = 'wasb://container/remote/log/location'
        self.remote_log_location = 'remote/log/location/1.log'
        self.local_log_location = 'local/log/location'
        self.container_name = "wasb-container"
        self.filename_template = '{try_number}.log'
        self.wasb_task_handler = WasbTaskHandler(
            base_log_folder=self.local_log_location,
            wasb_log_folder=self.wasb_log_folder,
            wasb_container=self.container_name,
            filename_template=self.filename_template,
            delete_local_copy=True,
        )

        date = datetime(2020, 8, 10)
        self.dag = DAG('dag_for_testing_file_task_handler', start_date=date)
        task = DummyOperator(task_id='task_for_testing_file_log_handler',
                             dag=self.dag)
        self.ti = TaskInstance(task=task, execution_date=date)
        self.ti.try_number = 1
        self.ti.state = State.RUNNING
        self.addCleanup(self.dag.clear)
Example #4
    def test_poke_context(self, mock_session_send):
        response = requests.Response()
        response.status_code = 200
        mock_session_send.return_value = response

        def resp_check(_, execution_date):
            if execution_date == DEFAULT_DATE:
                return True
            raise AirflowException('AirflowException raised here!')

        task = HttpSensor(
            task_id='http_sensor_poke_exception',
            http_conn_id='http_default',
            endpoint='',
            request_params={},
            response_check=resp_check,
            timeout=5,
            poke_interval=1,
            dag=self.dag,
        )

        task_instance = TaskInstance(task=task, execution_date=DEFAULT_DATE)
        task.execute(task_instance.get_template_context())
Example #5
    def test_error_sending_task(self):
        def fake_execute_command():
            pass

        with _prepare_app(execute=fake_execute_command):
            # fake_execute_command takes no arguments while execute_command takes 1,
            # which will cause TypeError when calling task.apply_async()
            executor = celery_executor.CeleryExecutor()
            task = BashOperator(task_id="test",
                                bash_command="true",
                                dag=DAG(dag_id='id'),
                                start_date=datetime.datetime.now())
            when = datetime.datetime.now()
            value_tuple = 'command', 1, None, \
                SimpleTaskInstance(ti=TaskInstance(task=task, execution_date=datetime.datetime.now()))
            key = ('fail', 'fake_simple_ti', when, 0)
            executor.queued_tasks[key] = value_tuple
            executor.heartbeat()
        self.assertEqual(0, len(executor.queued_tasks),
                         "Task should no longer be queued")
        self.assertEqual(
            executor.event_buffer[('fail', 'fake_simple_ti', when, 0)][0],
            State.FAILED)
Example #6
    def test_bigquery_operator_extra_link_when_single_query(
            self, mock_hook, session):
        bigquery_task = BigQueryExecuteQueryOperator(
            task_id=TASK_ID,
            sql='SELECT * FROM test_table',
            dag=self.dag,
        )
        self.dag.clear()
        session.query(XCom).delete()

        ti = TaskInstance(
            task=bigquery_task,
            execution_date=DEFAULT_DATE,
        )

        job_id = '12345'
        ti.xcom_push(key='job_id', value=job_id)

        assert f'https://console.cloud.google.com/bigquery?j={job_id}' == bigquery_task.get_extra_links(
            DEFAULT_DATE, BigQueryConsoleLink.name)

        assert '' == bigquery_task.get_extra_links(datetime(2019, 1, 1),
                                                   BigQueryConsoleLink.name)
Example #7
def test_run_airflow_dag(scaffold_dag):
    '''This test runs the sample Airflow dag using the TaskInstance API, directly from Python'''
    _n, _p, _d, static_path, editable_path = scaffold_dag

    execution_date = datetime.datetime.utcnow()

    import_module_from_path('demo_pipeline_static__scaffold', static_path)
    demo_pipeline = import_module_from_path('demo_pipeline', editable_path)

    _dag, tasks = demo_pipeline.make_dag(
        dag_id=demo_pipeline.DAG_ID,
        dag_description=demo_pipeline.DAG_DESCRIPTION,
        dag_kwargs=dict(default_args=demo_pipeline.DEFAULT_ARGS, **demo_pipeline.DAG_KWARGS),
        s3_conn_id=demo_pipeline.S3_CONN_ID,
        modified_docker_operator_kwargs=demo_pipeline.MODIFIED_DOCKER_OPERATOR_KWARGS,
        host_tmp_dir=demo_pipeline.HOST_TMP_DIR,
    )

    # These are in topo order already
    for task in tasks:
        ti = TaskInstance(task=task, execution_date=execution_date)
        context = ti.get_template_context()
        task.execute(context)
Example #8
    def test_parse_bucket_key_from_jinja(self, mock_hook):
        mock_hook.return_value.check_for_key.return_value = False

        Variable.set("test_bucket_key", "s3://bucket/key")

        execution_date = datetime(2020, 1, 1)

        dag = DAG("test_s3_key", start_date=execution_date)
        op = S3KeySensor(
            task_id='s3_key_sensor',
            bucket_key='{{ var.value.test_bucket_key }}',
            bucket_name=None,
            dag=dag,
        )

        ti = TaskInstance(task=op, execution_date=execution_date)
        context = ti.get_template_context()
        ti.render_templates(context)

        op.poke(None)

        self.assertEqual(op.bucket_key, "key")
        self.assertEqual(op.bucket_name, "bucket")
Example #9
def test_parent_not_executed():
    """
    A simple DAG with a BranchPythonOperator that does not follow op2. Parent task is not yet
    executed (no xcom data). NotPreviouslySkippedDep is met (no decision).
    """
    start_date = pendulum.datetime(2020, 1, 1)
    dag = DAG("test_parent_not_executed_dag",
              schedule_interval=None,
              start_date=start_date)
    op1 = BranchPythonOperator(task_id="op1",
                               python_callable=lambda: "op3",
                               dag=dag)
    op2 = DummyOperator(task_id="op2", dag=dag)
    op3 = DummyOperator(task_id="op3", dag=dag)
    op1 >> [op2, op3]

    ti2 = TaskInstance(op2, start_date)

    with create_session() as session:
        dep = NotPreviouslySkippedDep()
        assert len(list(dep.get_dep_statuses(ti2, session, DepContext()))) == 0
        assert dep.is_met(ti2, session)
        assert ti2.state == State.NONE
Example #10
def task_failed_deps(args):
    """
    Returns the unmet dependencies for a task instance from the perspective of the
    scheduler (i.e. why a task instance doesn't get scheduled and then queued by the
    scheduler, and then run by an executor).

    >>> airflow task_failed_deps tutorial sleep 2015-01-01
    Task instance dependencies not met:
    Dagrun Running: Task instance's dagrun did not exist: Unknown reason
    Trigger Rule: Task's trigger rule 'all_success' requires all upstream tasks to have succeeded, but found 1 non-success(es).
    """
    dag = get_dag(args)
    task = dag.get_task(task_id=args.task_id)
    ti = TaskInstance(task, args.execution_date)

    dep_context = DepContext(deps=SCHEDULER_DEPS)
    failed_deps = list(ti.get_failed_dep_statuses(dep_context=dep_context))
    if failed_deps:
        print("Task instance dependencies not met:")
        for dep in failed_deps:
            print("{}: {}".format(dep.dep_name, dep.reason))
    else:
        print("Task instance dependencies are all met.")
Example #11
 def test_instance_start_with_templates(self, _):
     dag_id = 'test_dag_id'
     configuration.load_test_config()
     args = {
         'start_date': DEFAULT_DATE
     }
     self.dag = DAG(dag_id, default_args=args)
     op = GceInstanceStartOperator(
         project_id='{{ dag.dag_id }}',
         zone='{{ dag.dag_id }}',
         resource_id='{{ dag.dag_id }}',
         gcp_conn_id='{{ dag.dag_id }}',
         api_version='{{ dag.dag_id }}',
         task_id='id',
         dag=self.dag
     )
     ti = TaskInstance(op, DEFAULT_DATE)
     ti.render_templates()
     self.assertEqual(dag_id, getattr(op, 'project_id'))
     self.assertEqual(dag_id, getattr(op, 'zone'))
     self.assertEqual(dag_id, getattr(op, 'resource_id'))
     self.assertEqual(dag_id, getattr(op, 'gcp_conn_id'))
     self.assertEqual(dag_id, getattr(op, 'api_version'))
Example #12
 def test_templates(self, _):
     dag_id = 'test_dag_id'
     configuration.load_test_config()
     args = {'start_date': DEFAULT_DATE}
     self.dag = DAG(dag_id, default_args=args)
     op = GoogleCloudStorageToGoogleCloudStorageTransferOperator(
         source_bucket='{{ dag.dag_id }}',
         destination_bucket='{{ dag.dag_id }}',
         description='{{ dag.dag_id }}',
         object_conditions={'exclude_prefixes': ['{{ dag.dag_id }}']},
         gcp_conn_id='{{ dag.dag_id }}',
         task_id=TASK_ID,
         dag=self.dag,
     )
     ti = TaskInstance(op, DEFAULT_DATE)
     ti.render_templates()
     self.assertEqual(dag_id, getattr(op, 'source_bucket'))
     self.assertEqual(dag_id, getattr(op, 'destination_bucket'))
     self.assertEqual(dag_id, getattr(op, 'description'))
     self.assertEqual(
         dag_id,
         getattr(op, 'object_conditions')['exclude_prefixes'][0])
     self.assertEqual(dag_id, getattr(op, 'gcp_conn_id'))
Example #13
    def test_bigquery_operator_defaults(self, mock_hook):
        operator = BigQueryOperator(
            task_id=TASK_ID,
            sql='Select * from test_table',
            dag=self.dag,
            default_args=self.args,
            schema_update_options=None
        )

        operator.execute(MagicMock())
        mock_hook.return_value \
            .get_conn.return_value \
            .cursor.return_value \
            .run_query \
            .assert_called_once_with(
                sql='Select * from test_table',
                destination_dataset_table=None,
                write_disposition='WRITE_EMPTY',
                allow_large_results=False,
                flatten_results=None,
                udf_config=None,
                maximum_billing_tier=None,
                maximum_bytes_billed=None,
                create_disposition='CREATE_IF_NEEDED',
                schema_update_options=None,
                query_params=None,
                labels=None,
                priority='INTERACTIVE',
                time_partitioning=None,
                api_resource_configs=None,
                cluster_fields=None,
                encryption_configuration=None
            )
        self.assertTrue(isinstance(operator.sql, str))
        ti = TaskInstance(task=operator, execution_date=DEFAULT_DATE)
        ti.render_templates()
        self.assertTrue(isinstance(ti.task.sql, str))
Example #14
File: cli.py  Project: rijuk/airflow
def backfill(args):
    logging.basicConfig(level=settings.LOGGING_LEVEL,
                        format=settings.SIMPLE_LOG_FORMAT)
    dagbag = DagBag(process_subdir(args.subdir))
    if args.dag_id not in dagbag.dags:
        raise AirflowException('dag_id could not be found')
    dag = dagbag.dags[args.dag_id]

    if args.start_date:
        args.start_date = dateutil.parser.parse(args.start_date)
    if args.end_date:
        args.end_date = dateutil.parser.parse(args.end_date)

    # If only one date is passed, using same as start and end
    args.end_date = args.end_date or args.start_date
    args.start_date = args.start_date or args.end_date

    if args.task_regex:
        dag = dag.sub_dag(task_regex=args.task_regex,
                          include_upstream=not args.ignore_dependencies)

    if args.dry_run:
        print("Dry run of DAG {0} on {1}".format(args.dag_id, args.start_date))
        for task in dag.tasks:
            print("Task {0}".format(task.task_id))
            ti = TaskInstance(task, args.start_date)
            ti.dry_run()
    else:
        dag.run(start_date=args.start_date,
                end_date=args.end_date,
                mark_success=args.mark_success,
                include_adhoc=args.include_adhoc,
                local=args.local,
                donot_pickle=(args.donot_pickle
                              or conf.getboolean('core', 'donot_pickle')),
                ignore_dependencies=args.ignore_dependencies,
                pool=args.pool)
Example #15
    def _solid(context):  # pylint: disable=unused-argument
        if AIRFLOW_EXECUTION_DATE_STR not in context.pipeline_run.tags:
            raise DagsterInvariantViolationError(
                'Could not find "{AIRFLOW_EXECUTION_DATE_STR}" in pipeline tags "{tags}". Please '
                'add "{AIRFLOW_EXECUTION_DATE_STR}" to pipeline tags before executing'
                .format(
                    AIRFLOW_EXECUTION_DATE_STR=AIRFLOW_EXECUTION_DATE_STR,
                    tags=context.pipeline_run.tags,
                ))
        execution_date_str = context.pipeline_run.tags.get(
            AIRFLOW_EXECUTION_DATE_STR)

        check.str_param(execution_date_str, 'execution_date_str')
        try:
            execution_date = dateutil.parser.parse(execution_date_str)
        except ValueError:
            raise DagsterInvariantViolationError(
                'Could not parse execution_date "{execution_date_str}". Please use datetime format '
                'compatible with  dateutil.parser.parse.'.format(
                    execution_date_str=execution_date_str, ))
        except OverflowError:
            raise DagsterInvariantViolationError(
                'Date "{execution_date_str}" exceeds the largest valid C integer on the system.'
                .format(execution_date_str=execution_date_str, ))

        check.inst_param(execution_date, 'execution_date', datetime.datetime)

        with replace_airflow_logger_handlers():
            task_instance = TaskInstance(task=task,
                                         execution_date=execution_date)

            ti_context = task_instance.get_template_context()
            task.render_template_fields(ti_context)

            task.execute(ti_context)

            return None
Example #16
def _get_ti(
    task: BaseOperator,
    exec_date_or_run_id: str,
    map_index: int,
    *,
    create_if_necessary: CreateIfNecessary = False,
    session: Session = NEW_SESSION,
) -> Tuple[TaskInstance, bool]:
    """Get the task instance through DagRun.run_id, if that fails, get the TI the old way"""
    if task.is_mapped:
        if map_index < 0:
            raise RuntimeError("No map_index passed to mapped task")
    elif map_index >= 0:
        raise RuntimeError("map_index passed to non-mapped task")
    dag_run, dr_created = _get_dag_run(
        dag=task.dag,
        exec_date_or_run_id=exec_date_or_run_id,
        create_if_necessary=create_if_necessary,
        session=session,
    )

    ti_or_none = dag_run.get_task_instance(task.task_id,
                                           map_index=map_index,
                                           session=session)
    if ti_or_none is None:
        if not create_if_necessary:
            raise TaskInstanceNotFound(
                f"TaskInstance for {task.dag.dag_id}, {task.task_id}, map={map_index} with "
                f"run_id or execution_date of {exec_date_or_run_id!r} not found"
            )
        # TODO: Validate map_index is in range?
        ti = TaskInstance(task, run_id=dag_run.run_id, map_index=map_index)
        ti.dag_run = dag_run
    else:
        ti = ti_or_none
    ti.refresh_from_task(task)
    return ti, dr_created
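
Unlike the older tests in this collection, which construct TaskInstance(task, execution_date) directly, this helper reflects the newer model in which a task instance belongs to a DagRun and is keyed by run_id. A condensed sketch of that lookup, mirroring the logic above and assuming the same Airflow 2.2+ style API:

from airflow.models import TaskInstance


def get_or_create_ti(dag_run, task, session):
    """Fetch the TI for `task` from `dag_run`, building an unsaved one if it is missing."""
    ti = dag_run.get_task_instance(task.task_id, session=session)
    if ti is None:
        # Mirrors the create_if_necessary branch of _get_ti above.
        ti = TaskInstance(task, run_id=dag_run.run_id)
        ti.dag_run = dag_run
    ti.refresh_from_task(task)
    return ti
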
Example #17
    def test_rerun_failed_subdag(self):
        """
        When there is an existing DagRun with failed state, reset the DagRun and the
        corresponding TaskInstances
        """
        dag = DAG('parent', default_args=default_args)
        subdag = DAG('parent.test', default_args=default_args)
        subdag_task = SubDagOperator(task_id='test',
                                     subdag=subdag,
                                     dag=dag,
                                     poke_interval=1)
        dummy_task = DummyOperator(task_id='dummy', dag=subdag)

        with create_session() as session:
            dummy_task_instance = TaskInstance(
                task=dummy_task,
                execution_date=DEFAULT_DATE,
                state=State.FAILED,
            )
            session.add(dummy_task_instance)
            session.commit()

        sub_dagrun = subdag.create_dagrun(
            run_type=DagRunType.SCHEDULED,
            execution_date=DEFAULT_DATE,
            state=State.FAILED,
            external_trigger=True,
        )

        subdag_task._reset_dag_run_and_task_instances(
            sub_dagrun, execution_date=DEFAULT_DATE)

        dummy_task_instance.refresh_from_db()
        assert dummy_task_instance.state == State.NONE

        sub_dagrun.refresh_from_db()
        assert sub_dagrun.state == State.RUNNING
Example #18
 def test_render_template(self):
     json_str = '''
         {
             "type": "{{ params.index_type }}",
             "datasource": "{{ params.datasource }}",
             "spec": {
                 "dataSchema": {
                     "granularitySpec": {
                         "intervals": ["{{ ds }}/{{ macros.ds_add(ds, 1) }}"]
                     }
                 }
             }
         }
     '''
     operator = DruidOperator(task_id='spark_submit_job',
                              json_index_file=json_str,
                              params={
                                  'index_type': 'index_hadoop',
                                  'datasource': 'datasource_prd'
                              },
                              dag=self.dag)
     ti = TaskInstance(operator, DEFAULT_DATE)
     ti.render_templates()
     expected = '''
         {
             "type": "index_hadoop",
             "datasource": "datasource_prd",
             "spec": {
                 "dataSchema": {
                     "granularitySpec": {
                         "intervals": ["2017-01-01/2017-01-02"]
                     }
                 }
             }
         }
     '''
     self.assertEqual(expected, getattr(operator, 'json_index_file'))
Example #19
    def test_task_states_for_dag_run(self):

        dag2 = DagBag().dags['example_python_operator']
        task2 = dag2.get_task(task_id='print_the_context')
        default_date2 = timezone.make_aware(datetime(2016, 1, 9))
        dag2.clear()

        ti2 = TaskInstance(task2, default_date2)

        ti2.set_state(State.SUCCESS)
        ti_start = ti2.start_date
        ti_end = ti2.end_date

        with redirect_stdout(io.StringIO()) as stdout:
            task_command.task_states_for_dag_run(
                self.parser.parse_args([
                    'tasks', 'states-for-dag-run', 'example_python_operator',
                    default_date2.isoformat()
                ]))
        actual_out = stdout.getvalue()

        formatted_rows = [(
            'example_python_operator',
            '2016-01-09 00:00:00+00:00',
            'print_the_context',
            'success',
            ti_start,
            ti_end,
        )]

        expected = tabulate(
            formatted_rows,
            ['dag', 'exec_date', 'task', 'state', 'start_date', 'end_date'],
            tablefmt="plain")

        # Check that prints, and log messages, are shown
        self.assertIn(expected.replace("\n", ""), actual_out.replace("\n", ""))
Example #20
def task_test(args, dag=None):
    """Tests task for a given dag_id"""
    # We want log output from operators etc. to show up here. Normally
    # airflow.task would redirect to a file, but here we want it to propagate
    # up to the normal airflow handler.
    handlers = logging.getLogger('airflow.task').handlers
    already_has_stream_handler = False
    for handler in handlers:
        already_has_stream_handler = isinstance(handler, logging.StreamHandler)
        if already_has_stream_handler:
            break
    if not already_has_stream_handler:
        logging.getLogger('airflow.task').propagate = True

    dag = dag or get_dag(args)

    task = dag.get_task(task_id=args.task_id)
    # Add CLI provided task_params to task.params
    if args.task_params:
        passed_in_params = json.loads(args.task_params)
        task.params.update(passed_in_params)
    ti = TaskInstance(task, args.execution_date)

    try:
        if args.dry_run:
            ti.dry_run()
        else:
            ti.run(ignore_task_deps=True, ignore_ti_state=True, test_mode=True)
    except Exception:  # pylint: disable=broad-except
        if args.post_mortem:
            try:
                debugger = importlib.import_module("ipdb")
            except ImportError:
                debugger = importlib.import_module("pdb")
            debugger.post_mortem()
        else:
            raise
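
Outside the CLI, the core of "airflow tasks test" is the ti.run() call above. A minimal sketch using the legacy execution_date-style constructor seen throughout this file; it assumes an initialized Airflow metadata database, since TaskInstance.run() still reads state from it even with test_mode=True:

import datetime

from airflow import DAG
from airflow.models import TaskInstance
from airflow.operators.bash_operator import BashOperator

dag = DAG('adhoc_test', start_date=datetime.datetime(2020, 1, 1))
task = BashOperator(task_id='hello', bash_command='echo hello', dag=dag)

ti = TaskInstance(task, datetime.datetime(2020, 1, 1))
# test_mode avoids persisting most state changes; the ignore_* flags skip
# dependency and prior-state checks, exactly as task_test does above.
ti.run(ignore_task_deps=True, ignore_ti_state=True, test_mode=True)
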
Example #21
    def test_task_states_for_dag_run(self):

        dag2 = DagBag().dags['example_python_operator']
        task2 = dag2.get_task(task_id='print_the_context')
        default_date2 = timezone.make_aware(datetime(2016, 1, 9))
        dag2.clear()

        ti2 = TaskInstance(task2, default_date2)

        ti2.set_state(State.SUCCESS)
        ti_start = ti2.start_date
        ti_end = ti2.end_date

        with redirect_stdout(io.StringIO()) as stdout:
            task_command.task_states_for_dag_run(
                self.parser.parse_args([
                    'tasks',
                    'states-for-dag-run',
                    'example_python_operator',
                    default_date2.isoformat(),
                    '--output',
                    "json",
                ]))
        actual_out = json.loads(stdout.getvalue())

        self.assertEqual(len(actual_out), 1)
        self.assertDictEqual(
            actual_out[0],
            {
                'dag_id': 'example_python_operator',
                'execution_date': '2016-01-09T00:00:00+00:00',
                'task_id': 'print_the_context',
                'state': 'success',
                'start_date': ti_start.isoformat(),
                'end_date': ti_end.isoformat(),
            },
        )
Example #22
    def setUp(self):
        super().setUp()
        self.local_log_location = 'local/log/location'
        self.filename_template = '{try_number}.log'
        self.log_id_template = '{dag_id}-{task_id}-{execution_date}-{try_number}'
        self.end_of_log_mark = 'end_of_log\n'
        self.write_stdout = False
        self.json_format = False
        self.json_fields = 'asctime,filename,lineno,levelname,message'
        self.es_task_handler = ElasticsearchTaskHandler(
            self.local_log_location,
            self.filename_template,
            self.log_id_template,
            self.end_of_log_mark,
            self.write_stdout,
            self.json_format,
            self.json_fields
        )

        self.es = elasticsearch.Elasticsearch(  # pylint: disable=invalid-name
            hosts=[{'host': 'localhost', 'port': 9200}]
        )
        self.index_name = 'test_index'
        self.doc_type = 'log'
        self.test_message = 'some random stuff'
        self.body = {'message': self.test_message, 'log_id': self.LOG_ID,
                     'offset': 1}

        self.es.index(index=self.index_name, doc_type=self.doc_type,
                      body=self.body, id=1)

        self.dag = DAG(self.DAG_ID, start_date=self.EXECUTION_DATE)
        task = DummyOperator(task_id=self.TASK_ID, dag=self.dag)
        self.ti = TaskInstance(task=task, execution_date=self.EXECUTION_DATE)
        self.ti.try_number = 1
        self.ti.state = State.RUNNING
        self.addCleanup(self.dag.clear)
Example #23
    def test_log_file_template_with_run_task(self):
        """Verify that the taskinstance has the right context for log_filename_template"""

        with mock.patch.object(task_command, "_run_task_by_selected_method"):
            with conf_vars({('core', 'dags_folder'): self.dag_path}):
                # increment the try_number of the task to be run
                dag = DagBag().get_dag(self.dag_id)
                task = dag.get_task(self.task_id)
                with create_session() as session:
                    dag.create_dagrun(
                        execution_date=self.execution_date,
                        start_date=timezone.utcnow(),
                        state=State.RUNNING,
                        run_type=DagRunType.MANUAL,
                        session=session,
                    )
                    ti = TaskInstance(task, self.execution_date)
                    ti.refresh_from_db(session=session, lock_for_update=True)
                    ti.try_number = 1  # not running, so starts at 0
                    session.merge(ti)

                log_file_path = os.path.join(
                    os.path.dirname(self.ti_log_file_path), "2.log")

                try:
                    task_command.task_run(
                        self.parser.parse_args([
                            'tasks', 'run', self.dag_id, self.task_id,
                            '--local', self.execution_date_str
                        ]))

                    assert os.path.exists(log_file_path)
                finally:
                    try:
                        os.remove(log_file_path)
                    except OSError:
                        pass
Example #24
    def test_file_task_handler_running(self):
        def task_callable(ti, **kwargs):
            ti.log.info("test")

        dag = DAG('dag_for_testing_file_task_handler', start_date=DEFAULT_DATE)
        task = PythonOperator(task_id='task_for_testing_file_log_handler',
                              dag=dag,
                              python_callable=task_callable,
                              provide_context=True)
        ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
        ti.try_number = 2
        ti.state = State.RUNNING

        logger = ti.log
        ti.log.disabled = False

        file_handler = next((handler for handler in logger.handlers
                             if handler.name == FILE_TASK_HANDLER), None)
        self.assertIsNotNone(file_handler)

        set_context(logger, ti)
        self.assertIsNotNone(file_handler.handler)
        # We expect set_context generates a file locally.
        log_filename = file_handler.handler.baseFilename
        self.assertTrue(os.path.isfile(log_filename))
        self.assertTrue(log_filename.endswith("2.log"), log_filename)

        logger.info("Test")

        # Return value of read must be a list.
        logs = file_handler.read(ti)
        self.assertTrue(isinstance(logs, list))
        # Logs for running tasks should show up too.
        self.assertEqual(len(logs), 2)

        # Remove the generated tmp log file.
        os.remove(log_filename)
Example #25
    def test_console_extra_link(self, mock_hook):
        training_op = MLEngineStartTrainingJobOperator(
            **self.TRAINING_DEFAULT_ARGS)

        ti = TaskInstance(
            task=training_op,
            execution_date=DEFAULT_DATE,
        )

        job_id = self.TRAINING_DEFAULT_ARGS['job_id']
        project_id = self.TRAINING_DEFAULT_ARGS['project_id']
        gcp_metadata = {
            "job_id": job_id,
            "project_id": project_id,
        }
        ti.xcom_push(key='gcp_metadata', value=gcp_metadata)

        assert (
            f"https://console.cloud.google.com/ai-platform/jobs/{job_id}?project={project_id}"
            == training_op.get_extra_links(DEFAULT_DATE,
                                           AIPlatformConsoleLink.name))

        assert '' == training_op.get_extra_links(datetime.datetime(2019, 1, 1),
                                                 AIPlatformConsoleLink.name)
Example #26
    def test_templates(self, _):
        dag_id = 'test_dag_id'
        args = {'start_date': DEFAULT_DATE}
        self.dag = DAG(dag_id, default_args=args)  # pylint:disable=attribute-defined-outside-init
        op = GoogleCloudStorageToGoogleCloudStorageTransferOperator(
            source_bucket='{{ dag.dag_id }}',
            destination_bucket='{{ dag.dag_id }}',
            description='{{ dag.dag_id }}',
            object_conditions={'exclude_prefixes': ['{{ dag.dag_id }}']},
            gcp_conn_id='{{ dag.dag_id }}',
            task_id=TASK_ID,
            dag=self.dag,
        )
        ti = TaskInstance(op, DEFAULT_DATE)
        ti.render_templates()
        self.assertEqual(dag_id, getattr(op, 'source_bucket'))
        self.assertEqual(dag_id, getattr(op, 'destination_bucket'))
        self.assertEqual(dag_id, getattr(op, 'description'))

        # pylint:disable=unsubscriptable-object
        self.assertEqual(dag_id, getattr(op, 'object_conditions')['exclude_prefixes'][0])
        # pylint:enable=unsubscriptable-object

        self.assertEqual(dag_id, getattr(op, 'gcp_conn_id'))
Example #27
    def test_bigquery_operator_extra_serialized_field_when_single_query(self):
        with self.dag:
            BigQueryExecuteQueryOperator(
                task_id=TASK_ID,
                sql='SELECT * FROM test_table',
            )
        serialized_dag = SerializedDAG.to_dict(self.dag)
        self.assertIn("sql", serialized_dag["dag"]["tasks"][0])

        dag = SerializedDAG.from_dict(serialized_dag)
        simple_task = dag.task_dict[TASK_ID]
        self.assertEqual(getattr(simple_task, "sql"), 'SELECT * FROM test_table')

        #########################################################
        # Verify Operator Links work with Serialized Operator
        #########################################################

        # Check Serialized version of operator link
        self.assertEqual(
            serialized_dag["dag"]["tasks"][0]["_operator_extra_links"],
            [{'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleLink': {}}]
        )

        # Check DeSerialized version of operator link
        self.assertIsInstance(list(simple_task.operator_extra_links)[0], BigQueryConsoleLink)

        ti = TaskInstance(task=simple_task, execution_date=DEFAULT_DATE)
        ti.xcom_push('job_id', 12345)

        # check for positive case
        url = simple_task.get_extra_links(DEFAULT_DATE, BigQueryConsoleLink.name)
        self.assertEqual(url, 'https://console.cloud.google.com/bigquery?j=12345')

        # check for negative case
        url2 = simple_task.get_extra_links(datetime(2017, 1, 2), BigQueryConsoleLink.name)
        self.assertEqual(url2, '')
Example #28
    def test_render_template_from_file(self):
        self.operator.job_flow_overrides = 'job.j2.json'
        self.operator.params = {'releaseLabel': '5.11.0'}

        ti = TaskInstance(self.operator, DEFAULT_DATE)
        ti.render_templates()

        self.emr_client_mock.run_job_flow.return_value = RUN_JOB_FLOW_SUCCESS_RETURN
        emr_session_mock = MagicMock()
        emr_session_mock.client.return_value = self.emr_client_mock
        boto3_session_mock = MagicMock(return_value=emr_session_mock)

        with patch('boto3.session.Session', boto3_session_mock):
            self.operator.execute(None)

        expected_args = {
            'Name':
            'test_job_flow',
            'ReleaseLabel':
            '5.11.0',
            'Steps': [{
                'Name': 'test_step',
                'ActionOnFailure': 'CONTINUE',
                'HadoopJarStep': {
                    'Jar':
                    'command-runner.jar',
                    'Args': [
                        '/usr/lib/spark/bin/run-example',
                        '2016-12-31',
                        '2017-01-01',
                    ]
                }
            }]
        }

        self.assertDictEqual(self.operator.job_flow_overrides, expected_args)
Example #29
def execute_tasks_in_dag(dag, tasks, run_id, execution_date):
    assert isinstance(dag, DAG)

    handler = logging.StreamHandler(sys.stdout)
    handler.setLevel(logging.DEBUG)
    handler.setFormatter(logging.Formatter(LOG_FORMAT))
    root = logging.getLogger("airflow.task.operators")
    root.setLevel(logging.DEBUG)
    root.addHandler(handler)

    dag_run = dag.create_dagrun(run_id=run_id, state="success", execution_date=execution_date)

    results = {}
    for task in tasks:
        ti = TaskInstance(task=task, execution_date=execution_date)
        context = ti.get_template_context()
        context["dag_run"] = dag_run

        try:
            results[ti] = task.execute(context)
        except AirflowSkipException as exc:
            results[ti] = exc

    return results
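
A hypothetical invocation of the helper above (the DAG and task names are made up for illustration; dag.create_dagrun() inside the helper writes to the Airflow metadata database, so an initialized database is assumed):

import datetime

from airflow import DAG
from airflow.operators.dummy_operator import DummyOperator

execution_date = datetime.datetime(2020, 1, 1)

dag = DAG('demo_dag', start_date=execution_date, schedule_interval=None)
first = DummyOperator(task_id='first', dag=dag)
second = DummyOperator(task_id='second', dag=dag)
first >> second

# Tasks are passed in execution order, as in test_run_airflow_dag earlier in this collection.
results = execute_tasks_in_dag(
    dag, [first, second], run_id='manual_demo_run', execution_date=execution_date
)
print(results)
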
Example #30
    def setUp(self):
        super(TestLogView, self).setUp()
        # Make sure that the configure_logging is not cached
        self.old_modules = dict(sys.modules)

        # Create a custom logging configuration
        logging_config = copy.deepcopy(DEFAULT_LOGGING_CONFIG)
        current_dir = os.path.dirname(os.path.abspath(__file__))
        logging_config['handlers']['task'][
            'base_log_folder'] = os.path.normpath(
                os.path.join(current_dir, 'test_logs'))
        logging_config['handlers']['task']['filename_template'] = \
            '{{ ti.dag_id }}/{{ ti.task_id }}/{{ ts | replace(":", ".") }}/{{ try_number }}.log'

        # Write the custom logging configuration to a file
        self.settings_folder = tempfile.mkdtemp()
        settings_file = os.path.join(self.settings_folder,
                                     "airflow_local_settings.py")
        new_logging_file = "LOGGING_CONFIG = {}".format(logging_config)
        with open(settings_file, 'w') as handle:
            handle.writelines(new_logging_file)
        sys.path.append(self.settings_folder)
        conf.set('core', 'logging_config_class',
                 'airflow_local_settings.LOGGING_CONFIG')

        app = application.create_app(testing=True)
        self.app = app.test_client()
        self.session = Session()
        from airflow.www.views import dagbag
        dag = DAG(self.DAG_ID, start_date=self.DEFAULT_DATE)
        task = DummyOperator(task_id=self.TASK_ID, dag=dag)
        dagbag.bag_dag(dag, parent_dag=dag, root_dag=dag)
        ti = TaskInstance(task=task, execution_date=self.DEFAULT_DATE)
        ti.try_number = 1
        self.session.merge(ti)
        self.session.commit()