def setUp(self):
     configuration.load_test_config()
     self.sagemaker = SageMakerBaseOperator(
         task_id='test_sagemaker_operator',
         aws_conn_id='sagemaker_test_id',
         config=config
     )
Example No. 2
 def setUp(self):
     configuration.load_test_config()
     from airflow.contrib.hooks.ssh_hook import SSHHook
     hook = SSHHook(ssh_conn_id='ssh_default')
     hook.no_host_key_check = True
     args = {
         'owner': 'airflow',
         'start_date': DEFAULT_DATE,
         'provide_context': True
     }
     dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args)
     dag.schedule_interval = '@once'
     self.hook = hook
     self.dag = dag
     self.test_dir = "/tmp"
     self.test_local_dir = "/tmp/tmp2"
     self.test_remote_dir = "/tmp/tmp1"
     self.test_local_filename = 'test_local_file'
     self.test_remote_filename = 'test_remote_file'
     self.test_local_filepath = '{0}/{1}'.format(self.test_dir,
                                                 self.test_local_filename)
     # Local Filepath with Intermediate Directory
     self.test_local_filepath_int_dir = '{0}/{1}'.format(self.test_local_dir,
                                                         self.test_local_filename)
     self.test_remote_filepath = '{0}/{1}'.format(self.test_dir,
                                                  self.test_remote_filename)
     # Remote Filepath with Intermediate Directory
     self.test_remote_filepath_int_dir = '{0}/{1}'.format(self.test_remote_dir,
                                                          self.test_remote_filename)
    def setUp(self):
        super().setUp()
        self.remote_log_base = 's3://bucket/remote/log/location'
        self.remote_log_location = 's3://bucket/remote/log/location/1.log'
        self.remote_log_key = 'remote/log/location/1.log'
        self.local_log_location = 'local/log/location'
        self.filename_template = '{try_number}.log'
        self.s3_task_handler = S3TaskHandler(
            self.local_log_location,
            self.remote_log_base,
            self.filename_template
        )

        configuration.load_test_config()
        date = datetime(2016, 1, 1)
        self.dag = DAG('dag_for_testing_file_task_handler', start_date=date)
        task = DummyOperator(task_id='task_for_testing_file_log_handler', dag=self.dag)
        self.ti = TaskInstance(task=task, execution_date=date)
        self.ti.try_number = 1
        self.ti.state = State.RUNNING
        self.addCleanup(self.dag.clear)

        self.conn = boto3.client('s3')
        # We need to create the bucket since this is all in Moto's 'virtual'
        # AWS account
        moto.core.moto_api_backend.reset()
        self.conn.create_bucket(Bucket="bucket")
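    # Hedged sketch of a companion test for the fixture above. s3_write() and
    # s3_read() are assumed from the Airflow 1.10 S3TaskHandler API and run
    # against the moto-backed bucket created in setUp.
    def test_s3_write_then_read(self):
        self.s3_task_handler.s3_write('some log text', self.remote_log_location)
        self.assertEqual(
            self.s3_task_handler.s3_read(self.remote_log_location),
            'some log text')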
 def setUp(self):
     configuration.load_test_config()
     args = {
         'owner': 'airflow',
         'start_date': datetime.datetime(2017, 1, 1)
     }
     self.dag = DAG('test_dag_id', default_args=args)
    def setUp(self):
        configuration.load_test_config()

        self.athena = AWSAthenaOperator(task_id='test_aws_athena_operator', query='SELECT * FROM TEST_TABLE',
                                        database='TEST_DATABASE', output_location='s3://test_s3_bucket/',
                                        client_request_token='eac427d0-1c6d-4dfb-96aa-2835d3ac6595',
                                        sleep_time=1)
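    # Hedged sketch of a smoke test for the operator above; the attribute
    # names simply mirror the constructor kwargs and are assumptions.
    def test_init(self):
        self.assertEqual(self.athena.task_id, 'test_aws_athena_operator')
        self.assertEqual(self.athena.query, 'SELECT * FROM TEST_TABLE')
        self.assertEqual(self.athena.database, 'TEST_DATABASE')
        self.assertEqual(self.athena.sleep_time, 1)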
 def setUp(self):
     configuration.load_test_config()
     self.sagemaker = SageMakerEndpointConfigOperator(
         task_id='test_sagemaker_operator',
         aws_conn_id='sagemaker_test_id',
         config=create_endpoint_config_params
     )
Example No. 7
    def setUp(self):
        super(TestLogView, self).setUp()

        # Create a custom logging configuration
        configuration.load_test_config()
        logging_config = copy.deepcopy(DEFAULT_LOGGING_CONFIG)
        current_dir = os.path.dirname(os.path.abspath(__file__))
        logging_config['handlers']['task']['base_log_folder'] = os.path.normpath(
            os.path.join(current_dir, 'test_logs'))
        logging_config['handlers']['task']['filename_template'] = \
            '{{ ti.dag_id }}/{{ ti.task_id }}/{{ ts | replace(":", ".") }}/{{ try_number }}.log'

        # Write the custom logging configuration to a file
        self.settings_folder = tempfile.mkdtemp()
        settings_file = os.path.join(self.settings_folder, "airflow_local_settings.py")
        new_logging_file = "LOGGING_CONFIG = {}".format(logging_config)
        with open(settings_file, 'w') as handle:
            handle.writelines(new_logging_file)
        sys.path.append(self.settings_folder)
        conf.set('core', 'logging_config_class', 'airflow_local_settings.LOGGING_CONFIG')

        app = application.create_app(testing=True)
        self.app = app.test_client()
        self.session = Session()
        from airflow.www.views import dagbag
        dag = DAG(self.DAG_ID, start_date=self.DEFAULT_DATE)
        task = DummyOperator(task_id=self.TASK_ID, dag=dag)
        dagbag.bag_dag(dag, parent_dag=dag, root_dag=dag)
        ti = TaskInstance(task=task, execution_date=self.DEFAULT_DATE)
        ti.try_number = 1
        self.session.merge(ti)
        self.session.commit()
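    # Hedged sketch of a matching tearDown: it undoes the sys.path and
    # logging_config_class changes made in setUp so later tests are not
    # affected. Assumes shutil and logging.config are imported at module level.
    def tearDown(self):
        logging.config.dictConfig(DEFAULT_LOGGING_CONFIG)
        shutil.rmtree(self.settings_folder)
        sys.path.remove(self.settings_folder)
        conf.set('core', 'logging_config_class', '')
        self.session.close()
        super(TestLogView, self).tearDown()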
 def setUp(self):
     configuration.load_test_config()
     db.merge_conn(
         models.Connection(
             conn_id='slack-webhook-default',
             extra='{"webhook_token": "your_token_here"}')
     )
Example No. 9
 def setUp(self):
     configuration.load_test_config()
     args = {
         'owner': 'airflow',
         'start_date': DEFAULT_DATE
     }
     self.dag = DAG(TEST_DAG_ID, default_args=args)
Example No. 10
 def setUp(self):
     session = requests.Session()
     adapter = requests_mock.Adapter()
     session.mount('mock', adapter)
     self.get_hook = HttpHook(method='GET')
     self.post_hook = HttpHook(method='POST')
     configuration.load_test_config()
Example No. 11
 def setUp(self):
     conf.load_test_config()
     app = application.create_app(testing=True)
     app.config['WTF_CSRF_METHODS'] = []
     self.app = app.test_client()
     self.session = Session()
     models.DagBag().get_dag("example_bash_operator").sync_to_db()
    def setUp(self):
        configuration.load_test_config()
        from airflow.contrib.hooks.ssh_hook import SSHHook
        from airflow.hooks.S3_hook import S3Hook

        hook = SSHHook(ssh_conn_id='ssh_default')
        s3_hook = S3Hook('aws_default')
        hook.no_host_key_check = True
        args = {
            'owner': 'airflow',
            'start_date': DEFAULT_DATE,
            'provide_context': True
        }
        dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args)
        dag.schedule_interval = '@once'

        self.hook = hook
        self.s3_hook = s3_hook

        self.ssh_client = self.hook.get_conn()
        self.sftp_client = self.ssh_client.open_sftp()

        self.dag = dag
        self.s3_bucket = BUCKET
        self.sftp_path = SFTP_PATH
        self.s3_key = S3_KEY
Example No. 13
    def setUp(self):

        if sys.version_info[0] == 3:
            raise unittest.SkipTest('TestSparkSubmitHook won\'t work with '
                                    'python3. No need to test anything here')

        configuration.load_test_config()
        db.merge_conn(
            models.Connection(
                conn_id='spark_yarn_cluster', conn_type='spark',
                host='yarn://yarn-master', extra='{"queue": "root.etl", "deploy-mode": "cluster"}')
        )
        db.merge_conn(
            models.Connection(
                conn_id='spark_default_mesos', conn_type='spark',
                host='mesos://host', port=5050)
        )

        db.merge_conn(
            models.Connection(
                conn_id='spark_home_set', conn_type='spark',
                host='yarn://yarn-master',
                extra='{"spark-home": "/opt/myspark"}')
        )

        db.merge_conn(
            models.Connection(
                conn_id='spark_home_not_set', conn_type='spark',
                host='yarn://yarn-master')
        )
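    # Hedged sketch of how a test might consume the connections above; the
    # _resolve_connection() helper and the keys it returns are assumptions
    # drawn from SparkSubmitHook internals and may differ across versions.
    def test_resolve_connection_yarn_cluster(self):
        hook = SparkSubmitHook(conn_id='spark_yarn_cluster')
        connection = hook._resolve_connection()
        self.assertEqual(connection['master'], 'yarn://yarn-master')
        self.assertEqual(connection['queue'], 'root.etl')
        self.assertEqual(connection['deploy_mode'], 'cluster')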
Example No. 14
 def setUp(self):
     configuration.load_test_config()
     db.merge_conn(
             models.Connection(
                     conn_id='jdbc_default', conn_type='jdbc',
                     host='jdbc://localhost/', port=443,
                     extra='{"extra__jdbc__drv_path": "/path1/test.jar,/path2/t.jar2", "extra__jdbc__drv_clsname": "com.driver.main"}'))
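 # Hedged sketch of exercising the connection above; it patches jaydebeapi's
 # connect() (assumed to be imported in airflow.hooks.jdbc_hook) so that no
 # real JDBC driver is required.
 @mock.patch('airflow.hooks.jdbc_hook.jaydebeapi.connect')
 def test_jdbc_conn(self, jdbc_connect_mock):
     hook = JdbcHook(jdbc_conn_id='jdbc_default')
     hook.get_conn()
     self.assertTrue(jdbc_connect_mock.called)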
Example No. 15
 def setUp(self):
     configuration.load_test_config()
     self._upload_dataframe()
     args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
     self.dag = DAG('test_dag_id', default_args=args)
     self.database = 'airflow'
     self.table = 'hive_server_hook'
     self.hql = """
     CREATE DATABASE IF NOT EXISTS {{ params.database }};
     USE {{ params.database }};
     DROP TABLE IF EXISTS {{ params.table }};
     CREATE TABLE IF NOT EXISTS {{ params.table }} (
         a int,
         b int)
     ROW FORMAT DELIMITED
     FIELDS TERMINATED BY ',';
     LOAD DATA LOCAL INPATH '{{ params.csv_path }}'
     OVERWRITE INTO TABLE {{ params.table }};
     """
     self.columns = ['{}.a'.format(self.table),
                     '{}.b'.format(self.table)]
     self.hook = HiveMetastoreHook()
     t = HiveOperator(
         task_id='HiveHook_' + str(random.randint(1, 10000)),
         params={
             'database': self.database,
             'table': self.table,
             'csv_path': self.local_path
         },
         hive_cli_conn_id='beeline_default',
         hql=self.hql, dag=self.dag)
     t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
           ignore_ti_state=True)
Example No. 16
 def setUp(self):
     configuration.load_test_config()
     args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
     self.dag = DAG('test_dag_id', default_args=args)
     self.next_day = (DEFAULT_DATE +
                      datetime.timedelta(days=1)).isoformat()[:10]
     self.database = 'airflow'
     self.partition_by = 'ds'
     self.table = 'static_babynames_partitioned'
     self.hql = """
     CREATE DATABASE IF NOT EXISTS {{ params.database }};
     USE {{ params.database }};
     DROP TABLE IF EXISTS {{ params.table }};
     CREATE TABLE IF NOT EXISTS {{ params.table }} (
         state string,
         year string,
         name string,
         gender string,
         num int)
     PARTITIONED BY ({{ params.partition_by }} string);
     ALTER TABLE {{ params.table }}
     ADD PARTITION({{ params.partition_by }}='{{ ds }}');
     """
     self.hook = HiveMetastoreHook()
     t = HiveOperator(
         task_id='HiveHook_' + str(random.randint(1, 10000)),
         params={
             'database': self.database,
             'table': self.table,
             'partition_by': self.partition_by
         },
         hive_cli_conn_id='beeline_default',
         hql=self.hql, dag=self.dag)
     t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
           ignore_ti_state=True)
Example No. 17
 def setUp(self):
     configuration.load_test_config()
     db.merge_conn(
         models.Connection(
             conn_id='cassandra_test', conn_type='cassandra',
             host='host-1,host-2', port='9042', schema='test_keyspace',
             extra='{"load_balancing_policy":"TokenAwarePolicy"}'))
    def setUp(self):
        super(TestElasticsearchTaskHandler, self).setUp()
        self.local_log_location = 'local/log/location'
        self.filename_template = '{try_number}.log'
        self.log_id_template = '{dag_id}-{task_id}-{execution_date}-{try_number}'
        self.end_of_log_mark = 'end_of_log\n'
        self.es_task_handler = ElasticsearchTaskHandler(
            self.local_log_location,
            self.filename_template,
            self.log_id_template,
            self.end_of_log_mark
        )

        self.es = elasticsearch.Elasticsearch(hosts=[{'host': 'localhost', 'port': 9200}])
        self.index_name = 'test_index'
        self.doc_type = 'log'
        self.test_message = 'some random stuff'
        self.body = {'message': self.test_message, 'log_id': self.LOG_ID,
                     'offset': 1}

        self.es.index(index=self.index_name, doc_type=self.doc_type,
                      body=self.body, id=1)

        configuration.load_test_config()
        self.dag = DAG(self.DAG_ID, start_date=self.EXECUTION_DATE)
        task = DummyOperator(task_id=self.TASK_ID, dag=self.dag)
        self.ti = TaskInstance(task=task, execution_date=self.EXECUTION_DATE)
        self.ti.try_number = 1
        self.ti.state = State.RUNNING
        self.addCleanup(self.dag.clear)
Example No. 19
    def test_xcom_disable_pickle_type(self):
        configuration.load_test_config()

        json_obj = {"key": "value"}
        execution_date = timezone.utcnow()
        key = "xcom_test1"
        dag_id = "test_dag1"
        task_id = "test_task1"

        configuration.set("core", "enable_xcom_pickling", "False")

        XCom.set(key=key,
                 value=json_obj,
                 dag_id=dag_id,
                 task_id=task_id,
                 execution_date=execution_date)

        ret_value = XCom.get_one(key=key,
                                 dag_id=dag_id,
                                 task_id=task_id,
                                 execution_date=execution_date)

        self.assertEqual(ret_value, json_obj)

        session = settings.Session()
        ret_value = session.query(XCom).filter(XCom.key == key, XCom.dag_id == dag_id,
                                               XCom.task_id == task_id,
                                               XCom.execution_date == execution_date
                                               ).first().value

        self.assertEqual(ret_value, json_obj)
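    # Hedged mirror of the test above with pickling switched on; it reuses
    # the same configuration.set / XCom round-trip shown there.
    def test_xcom_enable_pickle_type(self):
        configuration.load_test_config()

        json_obj = {"key": "value"}
        execution_date = timezone.utcnow()
        key = "xcom_test2"
        dag_id = "test_dag2"
        task_id = "test_task2"

        configuration.set("core", "enable_xcom_pickling", "True")

        XCom.set(key=key,
                 value=json_obj,
                 dag_id=dag_id,
                 task_id=task_id,
                 execution_date=execution_date)

        ret_value = XCom.get_one(key=key,
                                 dag_id=dag_id,
                                 task_id=task_id,
                                 execution_date=execution_date)

        self.assertEqual(ret_value, json_obj)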
Example No. 20
 def setUp(self):
     configuration.load_test_config()
     db.merge_conn(
             models.Connection(
                     conn_id='jira_default', conn_type='jira',
                     host='https://localhost/jira/', port=443,
                     extra='{"verify": "False", "project": "AIRFLOW"}'))
Example No. 21
 def setUp(self):
     conf.load_test_config()
     self.app, self.appbuilder = application.create_app(testing=True)
     self.app.config['WTF_CSRF_ENABLED'] = False
     self.client = self.app.test_client()
     self.session = Session()
     self.login()
 def setUp(self):
     configuration.load_test_config()
     self.sagemaker = SageMakerModelOperator(
         task_id='test_sagemaker_operator',
         aws_conn_id='sagemaker_test_id',
         config=create_model_params
     )
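 # Hedged sketch of executing the operator above with the SageMaker hook
 # patched out; SageMakerHook and its create_model method are assumptions
 # about the contrib hook this operator delegates to.
 @mock.patch.object(SageMakerHook, 'get_conn')
 @mock.patch.object(SageMakerHook, 'create_model')
 def test_execute(self, mock_create_model, mock_get_conn):
     mock_create_model.return_value = {'ModelArn': 'testarn',
                                       'ResponseMetadata': {'HTTPStatusCode': 200}}
     self.sagemaker.execute(None)
     mock_create_model.assert_called_once_with(create_model_params)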
    def setUp(self):
        configuration.load_test_config()
        db.merge_conn(
            models.Connection(
                conn_id='cassandra_test', conn_type='cassandra',
                host='host-1,host-2', port='9042', schema='test_keyspace',
                extra='{"load_balancing_policy":"TokenAwarePolicy"}'))
        db.merge_conn(
            models.Connection(
                conn_id='cassandra_default_with_schema', conn_type='cassandra',
                host='cassandra', port='9042', schema='s'))

        hook = CassandraHook("cassandra_default")
        session = hook.get_conn()
        cqls = [
            "DROP SCHEMA IF EXISTS s",
            """
                CREATE SCHEMA s WITH REPLICATION =
                    { 'class' : 'SimpleStrategy', 'replication_factor' : 1 }
            """,
        ]
        for cql in cqls:
            session.execute(cql)

        session.shutdown()
        hook.shutdown_cluster()
    def setUp(self):

        configuration.load_test_config()
        db.merge_conn(
            models.Connection(
                conn_id='spark_yarn_cluster', conn_type='spark',
                host='yarn://yarn-master',
                extra='{"queue": "root.etl", "deploy-mode": "cluster"}')
        )
        db.merge_conn(
            models.Connection(
                conn_id='spark_k8s_cluster', conn_type='spark',
                host='k8s://https://k8s-master',
                extra='{"spark-home": "/opt/spark", ' +
                      '"deploy-mode": "cluster", ' +
                      '"namespace": "mynamespace"}')
        )
        db.merge_conn(
            models.Connection(
                conn_id='spark_default_mesos', conn_type='spark',
                host='mesos://host', port=5050)
        )

        db.merge_conn(
            models.Connection(
                conn_id='spark_home_set', conn_type='spark',
                host='yarn://yarn-master',
                extra='{"spark-home": "/opt/myspark"}')
        )

        db.merge_conn(
            models.Connection(
                conn_id='spark_home_not_set', conn_type='spark',
                host='yarn://yarn-master')
        )
        db.merge_conn(
            models.Connection(
                conn_id='spark_binary_set', conn_type='spark',
                host='yarn', extra='{"spark-binary": "custom-spark-submit"}')
        )
        db.merge_conn(
            models.Connection(
                conn_id='spark_binary_and_home_set', conn_type='spark',
                host='yarn',
                extra='{"spark-home": "/path/to/spark_home", ' +
                      '"spark-binary": "custom-spark-submit"}')
        )
        db.merge_conn(
            models.Connection(
                conn_id='spark_standalone_cluster', conn_type='spark',
                host='spark://spark-standalone-master:6066',
                extra='{"spark-home": "/path/to/spark_home", "deploy-mode": "cluster"}')
        )
        db.merge_conn(
            models.Connection(
                conn_id='spark_standalone_cluster_client_mode', conn_type='spark',
                host='spark://spark-standalone-master:6066',
                extra='{"spark-home": "/path/to/spark_home", "deploy-mode": "client"}')
        )
    def setUp(self):
        configuration.load_test_config()

        self.sensor = AthenaSensor(task_id='test_athena_sensor',
                                   query_execution_id='abc',
                                   sleep_time=5,
                                   max_retires=1,
                                   aws_conn_id='aws_default')
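    # Hedged sketch of a success-path poke for the sensor above; patching
    # AWSAthenaHook.poll_query_status is an assumption about which hook
    # method the sensor calls and may need adjusting per Airflow version.
    @mock.patch.object(AWSAthenaHook, 'poll_query_status', side_effect=('SUCCEEDED',))
    def test_poke_success(self, mock_poll_query_status):
        self.assertTrue(self.sensor.poke(None))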
Example No. 26
 def setUp(self):
     configuration.load_test_config()
     db.merge_conn(
         models.Connection(
             conn_id='sqoop_test', conn_type='sqoop', schema='schema',
             host='rmdbs', port=5050, extra=json.dumps(self._config_json)
         )
     )
Example No. 27
 def setUp(self):
     super(TestMountPoint, self).setUp()
     configuration.load_test_config()
     configuration.conf.set("webserver", "base_url", "http://localhost:8080/test")
     config = dict()
     config['WTF_CSRF_METHODS'] = []
     app = application.cached_app(config=config, testing=True)
     self.client = Client(app)
 def setUp(self):
     configuration.load_test_config()
     args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
     dag = DAG('test_dag_id', default_args=args)
     self.dag = dag
     self.sql = 'SELECT 1'
     self.hook = AwsDynamoDBHook(
         aws_conn_id='aws_default', region_name='us-east-1')
 def setUp(self):
     configuration.load_test_config()
     db.merge_conn(
         Connection(
             conn_id='default-discord-webhook',
             host='https://discordapp.com/api/',
             extra='{"webhook_endpoint": "webhooks/00000/some-discord-token_000"}')
     )
    def setUp(self):

        configuration.load_test_config()
        db.merge_conn(
            models.Connection(
                conn_id='spark_default', conn_type='spark',
                host='yarn://yarn-master')
        )
Example No. 31
    def setUp(self):
        configuration.load_test_config()
        try:
            configuration.conf.add_section("api")
        except DuplicateSectionError:
            pass

        configuration.conf.set("api",
                               "auth_backend",
                               "airflow.contrib.auth.backends.password_auth")

        self.app = application.create_app(testing=True)

        session = Session()
        user = models.User()
        password_user = PasswordUser(user)
        password_user.username = '******'
        password_user.password = '******'
        session.add(password_user)
        session.commit()
        session.close()
    def setUp(self):

        if sys.version_info[0] == 3:
            raise unittest.SkipTest('TestSparkSubmitHook won\'t work with '
                                    'python3. No need to test anything here')

        configuration.load_test_config()
        db.merge_conn(
            models.Connection(
                conn_id='spark_yarn_cluster', conn_type='spark',
                host='yarn://yarn-master', extra='{"queue": "root.etl", "deploy-mode": "cluster"}')
        )
        db.merge_conn(
            models.Connection(
                conn_id='spark_default_mesos', conn_type='spark',
                host='mesos://host', port=5050)
        )

        db.merge_conn(
            models.Connection(
                conn_id='spark_home_set', conn_type='spark',
                host='yarn://yarn-master',
                extra='{"spark-home": "/opt/myspark"}')
        )

        db.merge_conn(
            models.Connection(
                conn_id='spark_home_not_set', conn_type='spark',
                host='yarn://yarn-master')
        )
        db.merge_conn(
            models.Connection(
                conn_id='spark_binary_set', conn_type='spark',
                host='yarn', extra='{"spark-binary": "custom-spark-submit"}')
        )
        db.merge_conn(
            models.Connection(
                conn_id='spark_binary_and_home_set', conn_type='spark',
                host='yarn', extra='{"spark-home": "/path/to/spark_home", "spark-binary": "custom-spark-submit"}')
        )
Example No. 33
    def setUp(self):
        super(TestLogView, self).setUp()

        configuration.load_test_config()
        logging_config = copy.deepcopy(DEFAULT_LOGGING_CONFIG)
        current_dir = os.path.dirname(os.path.abspath(__file__))
        logging_config['handlers']['file.task'][
            'base_log_folder'] = os.path.normpath(
                os.path.join(current_dir, 'test_logs'))
        logging.config.dictConfig(logging_config)

        app = application.create_app(testing=True)
        self.app = app.test_client()
        self.session = Session()
        from airflow.www.views import dagbag
        dag = DAG(self.DAG_ID, start_date=self.DEFAULT_DATE)
        task = DummyOperator(task_id=self.TASK_ID, dag=dag)
        dagbag.bag_dag(dag, parent_dag=dag, root_dag=dag)
        ti = TaskInstance(task=task, execution_date=self.DEFAULT_DATE)
        ti.try_number = 1
        self.session.merge(ti)
        self.session.commit()
    def setUp(self):
        super(TestElasticsearchTaskHandler, self).setUp()
        self.local_log_location = 'local/log/location'
        self.filename_template = '{try_number}.log'
        self.log_id_template = '{dag_id}-{task_id}-{execution_date}-{try_number}'
        self.end_of_log_mark = 'end_of_log\n'
        self.write_stdout = False
        self.json_format = False
        self.json_fields = 'asctime,filename,lineno,levelname,message'
        self.es_task_handler = ElasticsearchTaskHandler(
            self.local_log_location, self.filename_template,
            self.log_id_template, self.end_of_log_mark, self.write_stdout,
            self.json_format, self.json_fields)

        self.es = elasticsearch.Elasticsearch(hosts=[{
            'host': 'localhost',
            'port': 9200
        }])
        self.index_name = 'test_index'
        self.doc_type = 'log'
        self.test_message = 'some random stuff'
        self.body = {
            'message': self.test_message,
            'log_id': self.LOG_ID,
            'offset': 1
        }

        self.es.index(index=self.index_name,
                      doc_type=self.doc_type,
                      body=self.body,
                      id=1)

        configuration.load_test_config()
        self.dag = DAG(self.DAG_ID, start_date=self.EXECUTION_DATE)
        task = DummyOperator(task_id=self.TASK_ID, dag=self.dag)
        self.ti = TaskInstance(task=task, execution_date=self.EXECUTION_DATE)
        self.ti.try_number = 1
        self.ti.state = State.RUNNING
        self.addCleanup(self.dag.clear)
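    # Hedged sketch of reading back the document indexed above through the
    # handler; the (logs, metadatas) contract of read() is an assumption
    # based on the Airflow 1.10 ElasticsearchTaskHandler.
    def test_read(self):
        logs, metadatas = self.es_task_handler.read(self.ti, 1)
        self.assertEqual(1, len(logs))
        self.assertIn(self.test_message, logs[0])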
    def setUpClass(cls):
        super().setUpClass()
        configuration.load_test_config()
        dag_id = "extract_document_feature_prepare_dag"
        cls.prepare_dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE)

        get_list = GetEDINETDocumentListOperator(
                task_id="get_document_list", dag=cls.prepare_dag)
        get_document = GetEDINETDocumentSensor(
                max_retrieve=3, document_ids=("S100E2NM","S100E2S2"),
                task_id="get_document", dag=cls.prepare_dag, poke_interval=2)
        register_document = RegisterDocumentOperator(
                task_id="register_document", dag=cls.prepare_dag)
        extract_feature = ExtractDocumentFeaturesOperator(
                report_kinds=("annual",),
                task_id="extract_feature", dag=cls.prepare_dag)

        cls.prepare_dag.clear()
        get_list.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
        get_document.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
        register_document.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
        extract_feature.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
Example No. 36
    def setUp(self):
        super(TestLogView, self).setUp()
        # Make sure that the configure_logging is not cached
        self.old_modules = dict(sys.modules)

        conf.load_test_config()

        # Create a custom logging configuration
        configuration.load_test_config()
        logging_config = copy.deepcopy(DEFAULT_LOGGING_CONFIG)
        current_dir = os.path.dirname(os.path.abspath(__file__))
        logging_config['handlers']['task'][
            'base_log_folder'] = os.path.normpath(
                os.path.join(current_dir, 'test_logs'))
        logging_config['handlers']['task']['filename_template'] = \
            '{{ ti.dag_id }}/{{ ti.task_id }}/{{ ts | replace(":", ".") }}/{{ try_number }}.log'

        # Write the custom logging configuration to a file
        self.settings_folder = tempfile.mkdtemp()
        settings_file = os.path.join(self.settings_folder,
                                     "airflow_local_settings.py")
        new_logging_file = "LOGGING_CONFIG = {}".format(logging_config)
        with open(settings_file, 'w') as handle:
            handle.writelines(new_logging_file)
        sys.path.append(self.settings_folder)
        conf.set('core', 'logging_config_class',
                 'airflow_local_settings.LOGGING_CONFIG')

        app = application.create_app(testing=True)
        self.app = app.test_client()
        self.session = Session()
        from airflow.www.views import dagbag
        dag = DAG(self.DAG_ID, start_date=self.DEFAULT_DATE)
        task = DummyOperator(task_id=self.TASK_ID, dag=dag)
        dagbag.bag_dag(dag, parent_dag=dag, root_dag=dag)
        ti = TaskInstance(task=task, execution_date=self.DEFAULT_DATE)
        ti.try_number = 1
        self.session.merge(ti)
        self.session.commit()
Example No. 37
    def setUp(self):
        configuration.load_test_config()
        args = {
            'owner': 'airflow',
            'start_date': DEFAULT_DATE,
            'provide_context': True
        }
        dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args)
        dag.schedule_interval = '@once'
        self.dag = dag

        self.sensor = gcs_sensor.GoogleCloudStorageUploadSessionCompleteSensor(
            task_id='sensor',
            bucket='test-bucket',
            prefix='test-prefix/path',
            inactivity_period=12,
            poke_interval=10,
            min_objects=1,
            allow_delete=False,
            previous_num_objects=0,
            dag=self.dag)
        self.last_mocked_date = datetime(2019, 4, 24, 0, 0, 0)
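        # Hedged sketch of driving the sensor above; is_bucket_updated() and
        # the inactivity_seconds attribute are assumptions based on the 1.10
        # implementation of the upload-session sensor.
        def test_incoming_data_resets_inactivity(self):
            self.sensor.is_bucket_updated(2)
            self.assertEqual(self.sensor.inactivity_seconds, 0)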
    def setUp(self):

        configuration.load_test_config()
        db.merge_conn(
            models.Connection(
                conn_id='spark_yarn_cluster',
                conn_type='spark',
                host='yarn://yarn-master',
                extra='{"queue": "root.etl", "deploy-mode": "cluster"}'))
        db.merge_conn(
            models.Connection(conn_id='spark_default_mesos',
                              conn_type='spark',
                              host='mesos://host',
                              port=5050))

        db.merge_conn(
            models.Connection(conn_id='spark_home_set',
                              conn_type='spark',
                              host='yarn://yarn-master',
                              extra='{"spark-home": "/opt/myspark"}'))

        db.merge_conn(
            models.Connection(conn_id='spark_home_not_set',
                              conn_type='spark',
                              host='yarn://yarn-master'))
        db.merge_conn(
            models.Connection(conn_id='spark_binary_set',
                              conn_type='spark',
                              host='yarn',
                              extra='{"spark-binary": "custom-spark-submit"}'))
        db.merge_conn(
            models.Connection(
                conn_id='spark_binary_and_home_set',
                conn_type='spark',
                host='yarn',
                extra=
                '{"spark-home": "/path/to/spark_home", "spark-binary": "custom-spark-submit"}'
            ))
Example No. 39
 def test_templates(self, _):
     dag_id = 'test_dag_id'
     configuration.load_test_config()
     args = {'start_date': DEFAULT_DATE}
     self.dag = DAG(dag_id, default_args=args)
     op = GoogleCloudStorageToGoogleCloudStorageTransferOperator(
         source_bucket='{{ dag.dag_id }}',
         destination_bucket='{{ dag.dag_id }}',
         description='{{ dag.dag_id }}',
         object_conditions={'exclude_prefixes': ['{{ dag.dag_id }}']},
         gcp_conn_id='{{ dag.dag_id }}',
         task_id=TASK_ID,
         dag=self.dag,
     )
     ti = TaskInstance(op, DEFAULT_DATE)
     ti.render_templates()
     self.assertEqual(dag_id, getattr(op, 'source_bucket'))
     self.assertEqual(dag_id, getattr(op, 'destination_bucket'))
     self.assertEqual(dag_id, getattr(op, 'description'))
     self.assertEqual(
         dag_id,
         getattr(op, 'object_conditions')['exclude_prefixes'][0])
     self.assertEqual(dag_id, getattr(op, 'gcp_conn_id'))
Example No. 40
 def test_instance_start_with_templates(self, _):
     dag_id = 'test_dag_id'
     configuration.load_test_config()
     args = {
         'start_date': DEFAULT_DATE
     }
     self.dag = DAG(dag_id, default_args=args)
     op = GceInstanceStartOperator(
         project_id='{{ dag.dag_id }}',
         zone='{{ dag.dag_id }}',
         resource_id='{{ dag.dag_id }}',
         gcp_conn_id='{{ dag.dag_id }}',
         api_version='{{ dag.dag_id }}',
         task_id='id',
         dag=self.dag
     )
     ti = TaskInstance(op, DEFAULT_DATE)
     ti.render_templates()
     self.assertEqual(dag_id, getattr(op, 'project_id'))
     self.assertEqual(dag_id, getattr(op, 'zone'))
     self.assertEqual(dag_id, getattr(op, 'resource_id'))
     self.assertEqual(dag_id, getattr(op, 'gcp_conn_id'))
     self.assertEqual(dag_id, getattr(op, 'api_version'))
    def setUp(self):
        configuration.load_test_config()
        db.merge_conn(
            Connection(conn_id='azure_container_instance_test',
                       conn_type='azure_container_instances',
                       login='******',
                       password='******',
                       extra=json.dumps({
                           'tenantId': 'tenant_id',
                           'subscriptionId': 'subscription_id'
                       })))

        self.resources = ResourceRequirements(
            requests=ResourceRequests(memory_in_gb='4', cpu='1'))
        with patch(
                'azure.common.credentials.ServicePrincipalCredentials.__init__',
                autospec=True,
                return_value=None):
            with patch(
                    'azure.mgmt.containerinstance.ContainerInstanceManagementClient'
            ):
                self.testHook = AzureContainerInstanceHook(
                    conn_id='azure_container_instance_test')
Example No. 42
    def setUp(self):
        configuration.load_test_config()
        db.merge_conn(
            Connection(
                conn_id='mongo_test', conn_type='mongo',
                host='mongo', port='27017', schema='test'))

        args = {
            'owner': 'airflow',
            'start_date': DEFAULT_DATE
        }
        self.dag = DAG('test_dag_id', default_args=args)

        hook = MongoHook('mongo_test')
        hook.insert_one('foo', {'bar': 'baz'})

        self.sensor = MongoSensor(
            task_id='test_task',
            mongo_conn_id='mongo_test',
            dag=self.dag,
            collection='foo',
            query={'bar': 'baz'}
        )
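    # Hedged sketch of poking the sensor above; it assumes poke(context)
    # returns True once the inserted document matches the query.
    def test_poke(self):
        self.assertTrue(self.sensor.poke(None))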
Example No. 43
    def setUp(self):
        load_test_config()

        db.merge_conn(
            models.Connection(
                conn_id='google_test',
                host='google',
                conn_type="google_cloud_platform",
                schema='refresh_token',
                login='******',
                password='******'
            )
        )
        db.merge_conn(
            models.Connection(
                conn_id='s3_test',
                conn_type='s3',
                schema='test',
                extra='{"aws_access_key_id": "aws_access_key_id", "aws_secret_access_key":'
                      ' "aws_secret_access_key"}'
            )
        )

        self.kwargs = {
            'gcp_conn_id': 'google_test',
            'google_api_service_name': 'test_service',
            'google_api_service_version': 'v3',
            'google_api_endpoint_path': 'analyticsreporting.reports.batchGet',
            'google_api_endpoint_params': {},
            'google_api_pagination': False,
            'google_api_num_retries': 0,
            'aws_conn_id': 's3_test',
            's3_destination_key': 'test/google_api_to_s3_test.csv',
            's3_overwrite': True,
            'task_id': 'task_id',
            'dag': None
        }
Example No. 44
 def setUp(self):
     configuration.load_test_config()
     args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
     self.dag = DAG('test_dag_id', default_args=args)
     self.next_day = (DEFAULT_DATE +
                      datetime.timedelta(days=1)).isoformat()[:10]
     self.database = 'airflow'
     self.partition_by = 'ds'
     self.table = 'static_babynames_partitioned'
     self.hql = """
     CREATE DATABASE IF NOT EXISTS {{ params.database }};
     USE {{ params.database }};
     DROP TABLE IF EXISTS {{ params.table }};
     CREATE TABLE IF NOT EXISTS {{ params.table }} (
         state string,
         year string,
         name string,
         gender string,
         num int)
     PARTITIONED BY ({{ params.partition_by }} string);
     ALTER TABLE {{ params.table }}
     ADD PARTITION({{ params.partition_by }}='{{ ds }}');
     """
     self.hook = HiveMetastoreHook()
     t = operators.hive_operator.HiveOperator(
         task_id='HiveHook_' + str(random.randint(1, 10000)),
         params={
             'database': self.database,
             'table': self.table,
             'partition_by': self.partition_by
         },
         hive_cli_conn_id='beeline_default',
         hql=self.hql,
         dag=self.dag)
     t.run(start_date=DEFAULT_DATE,
           end_date=DEFAULT_DATE,
           ignore_ti_state=True)
    def setUp(self):
        configuration.load_test_config()

        hook = SSHHook(ssh_conn_id='ssh_default')
        s3_hook = S3Hook('aws_default')
        hook.no_host_key_check = True
        args = {
            'owner': 'airflow',
            'start_date': DEFAULT_DATE,
            'provide_context': True
        }
        dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once', default_args=args)
        dag.schedule_interval = '@once'

        self.hook = hook
        self.s3_hook = s3_hook

        self.ssh_client = self.hook.get_conn()
        self.sftp_client = self.ssh_client.open_sftp()

        self.dag = dag
        self.s3_bucket = BUCKET
        self.sftp_path = SFTP_PATH
        self.s3_key = S3_KEY
Example No. 46
    def setUp(self, aws_hook_mock):
        configuration.load_test_config()

        self.aws_hook_mock = aws_hook_mock
        self.ecs = ECSOperator(
            task_id='task',
            task_definition='t',
            cluster='c',
            overrides={},
            aws_conn_id=None,
            region_name='eu-west-1',
            group='group',
            placement_constraints=[
                {
                    'expression': 'attribute:ecs.instance-type =~ t2.*',
                    'type': 'memberOf'
                }
            ],
            network_configuration={
                'awsvpcConfiguration': {
                    'securityGroups': ['sg-123abc']
                }
            }
        )
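        # Hedged sketch checking that the operator built above kept its
        # arguments; these attribute names mirror the constructor kwargs and
        # are assumptions.
        def test_init(self):
            self.assertEqual(self.ecs.region_name, 'eu-west-1')
            self.assertEqual(self.ecs.task_definition, 't')
            self.assertEqual(self.ecs.cluster, 'c')
            self.assertIsNone(self.ecs.aws_conn_id)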
    def setUp(self):
        super(TestS3TaskHandler, self).setUp()
        self.remote_log_location = 's3://bucket/remote/log/location'
        self.remote_log_key = 'remote/log/location'
        self.local_log_location = 'local/log/location'
        self.filename_template = '{try_number}.log'
        self.s3_task_handler = S3TaskHandler(self.local_log_location,
                                             self.remote_log_location,
                                             self.filename_template)

        configuration.load_test_config()
        date = datetime(2016, 1, 1)
        self.dag = DAG('dag_for_testing_file_task_handler', start_date=date)
        task = DummyOperator(task_id='task_for_testing_file_log_handler',
                             dag=self.dag)
        self.ti = TaskInstance(task=task, execution_date=date)
        self.ti.try_number = 1
        self.addCleanup(self.dag.clear)

        self.conn = boto3.client('s3')
        # We need to create the bucket since this is all in Moto's 'virtual'
        # AWS account
        moto.core.moto_api_backend.reset()
        self.conn.create_bucket(Bucket="bucket")
Example No. 48
    def setUp(self, aws_hook_mock):
        configuration.load_test_config()

        self.aws_hook_mock = aws_hook_mock
        self.ecs_operator_args = {
            'task_id': 'task',
            'task_definition': 't',
            'cluster': 'c',
            'overrides': {},
            'aws_conn_id': None,
            'region_name': 'eu-west-1',
            'group': 'group',
            'placement_constraints': [{
                'expression': 'attribute:ecs.instance-type =~ t2.*',
                'type': 'memberOf'
            }],
            'network_configuration': {
                'awsvpcConfiguration': {
                    'securityGroups': ['sg-123abc'],
                    'subnets': ['subnet-123456ab']
                }
            }
        }
        self.ecs = ECSOperator(**self.ecs_operator_args)
Example No. 49
 def setUp(self):
     configuration.load_test_config()
     args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
     dag = DAG(TEST_DAG_ID, default_args=args)
     self.dag = dag
Example No. 50
 def setUp(self):
     super().setUp()
     configuration.load_test_config()
     app, _ = application.create_app(testing=True)
     self.app = app.test_client()
Example No. 51
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import print_function

from airflow import DAG, configuration, operators
from airflow.utils.tests import skipUnlessImported
from airflow.utils import timezone

import os
import mock
import unittest

configuration.load_test_config()

DEFAULT_DATE = timezone.datetime(2015, 1, 1)
DEFAULT_DATE_ISO = DEFAULT_DATE.isoformat()
DEFAULT_DATE_DS = DEFAULT_DATE_ISO[:10]
TEST_DAG_ID = 'unit_test_dag'


@skipUnlessImported('airflow.operators.mysql_operator', 'MySqlOperator')
class MySqlTest(unittest.TestCase):
    def setUp(self):
        configuration.load_test_config()
        args = {
            'owner': 'airflow',
            'mysql_conn_id': 'airflow_db',
            'start_date': DEFAULT_DATE
        }
        dag = DAG(TEST_DAG_ID, default_args=args)
        self.dag = dag
Example No. 52
 def setUpClass(cls):
     os.environ['AIRFLOW__TESTSECTION__TESTKEY'] = 'testvalue'
     os.environ['AIRFLOW__TESTSECTION__TESTPERCENT'] = 'with%percent'
     configuration.load_test_config()
     conf.set('core', 'percent', 'with%%inside')
Example No. 53
def airflow_init_db(airflow_home):
    configuration.load_test_config()
    initdb()
Example No. 54
 def setUp(self):
     configuration.load_test_config()
     self.channel_mock = mock.patch('grpc.Channel').start()
Example No. 55
 def setUp(self):
     configuration.load_test_config()
 def setUp(self):
     configuration.load_test_config()
     self.s3_test_url = "s3://test/this/is/not/a-real-key.txt"
Example No. 57
 def setUp(self):
     configuration.load_test_config()
     self.hook = SFTPHook()
     os.makedirs(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
     with open(os.path.join(TMP_PATH, TMP_FILE_FOR_TESTS), 'a') as f:
         f.write('Test file')
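 # Hedged sketch of a check against the directory created above;
 # describe_directory() is part of SFTPHook, though the exact return shape
 # (a dict keyed by entry name) is an assumption.
 def test_describe_directory(self):
     output = self.hook.describe_directory(TMP_PATH)
     self.assertTrue(TMP_DIR_FOR_TESTS in output)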
 def setUp(self):
     super(TestSnowflakeOperator, self).setUp()
     configuration.load_test_config()
     args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
     dag = DAG(TEST_DAG_ID, default_args=args)
     self.dag = dag
Example No. 59
 def setUp(self):
     configuration.load_test_config()
     self.dagbag = DagBag(dag_folder=DEV_NULL, include_examples=True)
     self.args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
     dag = DAG(TEST_DAG_ID, default_args=self.args)
     self.dag = dag
Example No. 60
 def setUp(self):
     super(TestApiExperimental, self).setUp()
     configuration.load_test_config()
     app = application.create_app(testing=True)
     self.app = app.test_client()