Example #1
    def registered(self, driver, frameworkId, masterInfo):
        self.log.info(
            "AirflowScheduler registered to Mesos with framework ID %s",
            frameworkId.value)

        if configuration.conf.getboolean('mesos', 'CHECKPOINT') and \
                configuration.conf.get('mesos', 'FAILOVER_TIMEOUT'):
            # Import here to work around a circular import error
            from airflow.models.connection import Connection

            # Update the Framework ID in the database.
            session = Session()
            conn_id = FRAMEWORK_CONNID_PREFIX + get_framework_name()
            connection = session.query(Connection).filter_by(
                conn_id=conn_id).first()
            if connection is None:
                connection = Connection(conn_id=conn_id,
                                        conn_type='mesos_framework-id',
                                        extra=frameworkId.value)
            else:
                connection.extra = frameworkId.value

            session.add(connection)
            session.commit()
            Session.remove()
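A hedged companion sketch (not part of the original scheduler code): reading the framework ID back from the same Connection row, e.g. on scheduler restart. `FRAMEWORK_CONNID_PREFIX`, `get_framework_name()` and the scoped `Session` are assumed to be the same objects used above.

def get_stored_framework_id():
    # Import here for the same circular-import reason as above.
    from airflow.models.connection import Connection

    session = Session()
    conn_id = FRAMEWORK_CONNID_PREFIX + get_framework_name()
    connection = session.query(Connection).filter_by(conn_id=conn_id).first()
    try:
        # registered() stored the framework ID in the `extra` field.
        return connection.extra if connection else None
    finally:
        Session.remove()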
Example #2
 def test_create_operator_with_wrong_parameters(self,
                                                project_id,
                                                location,
                                                instance_name,
                                                database_type,
                                                use_proxy,
                                                use_ssl,
                                                sql,
                                                message,
                                                get_connections):
     connection = Connection()
     connection.parse_from_uri(
         "gcpcloudsql://*****:*****@8.8.8.8:3200/testdb?"
         "database_type={database_type}&"
         "project_id={project_id}&location={location}&instance={instance_name}&"
         "use_proxy={use_proxy}&use_ssl={use_ssl}".
         format(database_type=database_type,
                project_id=project_id,
                location=location,
                instance_name=instance_name,
                use_proxy=use_proxy,
                use_ssl=use_ssl))
     get_connections.return_value = [connection]
     with self.assertRaises(AirflowException) as cm:
         CloudSqlQueryOperator(
             sql=sql,
             task_id='task_id'
         )
     err = cm.exception
     self.assertIn(message, str(err))
Example #3
 def test_create_operator_with_correct_parameters_mysql_ssl(self, get_connections):
     connection = Connection()
     connection.parse_from_uri(
         "gcpcloudsql://*****:*****@8.8.8.8:3200/testdb?database_type=mysql&"
         "project_id=example-project&location=europe-west1&instance=testdb&"
         "use_proxy=False&use_ssl=True&sslcert=/bin/bash&"
         "sslkey=/bin/bash&sslrootcert=/bin/bash")
     get_connections.return_value = [connection]
     operator = CloudSqlQueryOperator(
         sql=['SELECT * FROM TABLE'],
         task_id='task_id'
     )
     operator.cloudsql_db_hook.create_connection()
     try:
         db_hook = operator.cloudsql_db_hook.get_database_hook()
         conn = db_hook._get_connections_from_db(db_hook.mysql_conn_id)[0]
     finally:
         operator.cloudsql_db_hook.delete_connection()
     self.assertEqual('mysql', conn.conn_type)
     self.assertEqual('8.8.8.8', conn.host)
     self.assertEqual(3200, conn.port)
     self.assertEqual('testdb', conn.schema)
     self.assertEqual('/bin/bash', json.loads(conn.extra_dejson['ssl'])['cert'])
     self.assertEqual('/bin/bash', json.loads(conn.extra_dejson['ssl'])['key'])
     self.assertEqual('/bin/bash', json.loads(conn.extra_dejson['ssl'])['ca'])
Example #4
 def test_create_operator_with_correct_parameters_mysql_proxy_socket(self,
                                                                     get_connections):
     connection = Connection()
     connection.parse_from_uri(
         "gcpcloudsql://*****:*****@8.8.8.8:3200/testdb?database_type=mysql&"
         "project_id=example-project&location=europe-west1&instance=testdb&"
         "use_proxy=True&sql_proxy_use_tcp=False")
     get_connections.return_value = [connection]
     operator = CloudSqlQueryOperator(
         sql=['SELECT * FROM TABLE'],
         task_id='task_id'
     )
     operator.cloudsql_db_hook.create_connection()
     try:
         db_hook = operator.cloudsql_db_hook.get_database_hook()
         conn = db_hook._get_connections_from_db(db_hook.mysql_conn_id)[0]
     finally:
         operator.cloudsql_db_hook.delete_connection()
     self.assertEqual('mysql', conn.conn_type)
     self.assertEqual('localhost', conn.host)
     self.assertIn('/tmp', conn.extra_dejson['unix_socket'])
     self.assertIn('example-project:europe-west1:testdb',
                   conn.extra_dejson['unix_socket'])
     self.assertIsNone(conn.port)
     self.assertEqual('testdb', conn.schema)
Example #5
def _create_connection(conn_id: str, value: Any):
    """
    Creates a connection based on a URL or JSON object.
    """
    from airflow.models.connection import Connection

    if isinstance(value, str):
        return Connection(conn_id=conn_id, uri=value)
    if isinstance(value, dict):
        connection_parameter_names = get_connection_parameter_names()
        current_keys = set(value.keys())
        if not current_keys.issubset(connection_parameter_names):
            illegal_keys = current_keys - connection_parameter_names
            illegal_keys_list = ", ".join(illegal_keys)
            raise AirflowException(
                f"The object have illegal keys: {illegal_keys_list}. "
                f"The dictionary can only contain the following keys: {connection_parameter_names}"
            )

        if "conn_id" in current_keys and conn_id != value["conn_id"]:
            raise AirflowException(
                f"Mismatch conn_id. "
                f"The dictionary key has the value: {value['conn_id']}. "
                f"The item has the value: {conn_id}.")
        value["conn_id"] = conn_id
        return Connection(**value)
    raise AirflowException(
        f"Unexpected value type: {type(value)}. The connection can only be defined using a string or object."
    )
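A quick usage sketch for `_create_connection`; the connection ID and credentials are illustrative, and both accepted value shapes are shown.

conn_from_uri = _create_connection(
    "my_postgres", "postgres://user:pass@localhost:5432/mydb")
conn_from_dict = _create_connection(
    "my_postgres",
    {"conn_type": "postgres", "host": "localhost",
     "login": "user", "password": "pass", "port": 5432})
# A dict containing an unknown key, or a conn_id that disagrees with the
# first argument, raises AirflowException per the checks above.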
Example #7
    def setUp(self):
        configuration.load_test_config()
        db.merge_conn(
            Connection(conn_id='cassandra_test',
                       conn_type='cassandra',
                       host='host-1,host-2',
                       port='9042',
                       schema='test_keyspace',
                       extra='{"load_balancing_policy":"TokenAwarePolicy"}'))
        db.merge_conn(
            Connection(conn_id='cassandra_default_with_schema',
                       conn_type='cassandra',
                       host='cassandra',
                       port='9042',
                       schema='s'))

        hook = CassandraHook("cassandra_default")
        session = hook.get_conn()
        cqls = [
            "DROP SCHEMA IF EXISTS s",
            """
                CREATE SCHEMA s WITH REPLICATION =
                    { 'class' : 'SimpleStrategy', 'replication_factor' : 1 }
            """,
        ]
        for cql in cqls:
            session.execute(cql)

        session.shutdown()
        hook.shutdown_cluster()
Example #10
    def test_cloudsql_hook_delete_connection_on_exception(
            self, get_connections, run, get_connection, delete_connection):
        connection = Connection()
        connection.parse_from_uri(
            "gcpcloudsql://*****:*****@127.0.0.1:3200/testdb?database_type=mysql&"
            "project_id=example-project&location=europe-west1&instance=testdb&"
            "use_proxy=False")
        get_connection.return_value = connection

        db_connection = Connection()
        db_connection.host = "127.0.0.1"
        db_connection.set_extra(
            json.dumps({
                "project_id": "example-project",
                "location": "europe-west1",
                "instance": "testdb",
                "database_type": "mysql"
            }))
        get_connections.return_value = [db_connection]
        run.side_effect = Exception("Exception when running a query")
        operator = CloudSqlQueryOperator(sql=['SELECT * FROM TABLE'],
                                         task_id='task_id')
        with self.assertRaises(Exception) as cm:
            operator.execute(None)
        err = cm.exception
        self.assertEqual("Exception when running a query", str(err))
        delete_connection.assert_called_once_with()
Example #11
class TestCloudSqlDatabaseHook(unittest.TestCase):
    @mock.patch(
        'airflow.contrib.hooks.gcp_sql_hook.CloudSqlDatabaseHook.get_connection'
    )
    def setUp(self, m):
        super().setUp()

        self.sql_connection = Connection(
            conn_id='my_gcp_sql_connection',
            conn_type='gcpcloudsql',
            login='******',
            password='******',
            host='host',
            schema='schema',
            extra='{"database_type":"postgres", "location":"my_location", '
            '"instance":"my_instance", "use_proxy": true, '
            '"project_id":"my_project"}')
        self.connection = Connection(
            conn_id='my_gcp_connection',
            conn_type='google_cloud_platform',
        )
        scopes = [
            "https://www.googleapis.com/auth/pubsub",
            "https://www.googleapis.com/auth/datastore",
            "https://www.googleapis.com/auth/bigquery",
            "https://www.googleapis.com/auth/devstorage.read_write",
            "https://www.googleapis.com/auth/logging.write",
            "https://www.googleapis.com/auth/cloud-platform",
        ]
        conn_extra = {
            "extra__google_cloud_platform__scope": ",".join(scopes),
            "extra__google_cloud_platform__project": "your-gcp-project",
            "extra__google_cloud_platform__key_path":
                '/var/local/google_cloud_default.json',
        }
        conn_extra_json = json.dumps(conn_extra)
        self.connection.set_extra(conn_extra_json)

        m.side_effect = [self.sql_connection, self.connection]
        self.db_hook = CloudSqlDatabaseHook(
            gcp_cloudsql_conn_id='my_gcp_sql_connection',
            gcp_conn_id='my_gcp_connection')

    def test_get_sqlproxy_runner(self):
        self.db_hook._generate_connection_uri()
        sqlproxy_runner = self.db_hook.get_sqlproxy_runner()
        self.assertEqual(sqlproxy_runner.gcp_conn_id, self.connection.conn_id)
        project = self.sql_connection.extra_dejson['project_id']
        location = self.sql_connection.extra_dejson['location']
        instance = self.sql_connection.extra_dejson['instance']
        instance_spec = "{project}:{location}:{instance}".format(
            project=project, location=location, instance=instance)
        self.assertEqual(sqlproxy_runner.instance_specification, instance_spec)
Example #12
 def setUp(self):
     configuration.load_test_config()
     db.merge_conn(
         Connection(conn_id='slack-webhook-default',
                    extra='{"webhook_token": "your_token_here"}'))
     db.merge_conn(
         Connection(conn_id='slack-webhook-url',
                    host='https://hooks.slack.com/services/T000/B000/XXX'))
     db.merge_conn(
         Connection(conn_id='slack-webhook-host',
                    host='https://hooks.slack.com/services/T000/'))
Example #13
    def setUp(self):
        configuration.load_test_config()
        db.merge_conn(
            Connection(
                conn_id='spark_yarn_cluster',
                conn_type='spark',
                host='yarn://yarn-master',
                extra='{"queue": "root.etl", "deploy-mode": "cluster"}'))
        db.merge_conn(
            Connection(conn_id='spark_k8s_cluster',
                       conn_type='spark',
                       host='k8s://https://k8s-master',
                       extra='{"spark-home": "/opt/spark", ' +
                       '"deploy-mode": "cluster", ' +
                       '"namespace": "mynamespace"}'))
        db.merge_conn(
            Connection(conn_id='spark_default_mesos',
                       conn_type='spark',
                       host='mesos://host',
                       port=5050))

        db.merge_conn(
            Connection(conn_id='spark_home_set',
                       conn_type='spark',
                       host='yarn://yarn-master',
                       extra='{"spark-home": "/opt/myspark"}'))

        db.merge_conn(
            Connection(conn_id='spark_home_not_set',
                       conn_type='spark',
                       host='yarn://yarn-master'))
        db.merge_conn(
            Connection(conn_id='spark_binary_set',
                       conn_type='spark',
                       host='yarn',
                       extra='{"spark-binary": "custom-spark-submit"}'))
        db.merge_conn(
            Connection(conn_id='spark_binary_and_home_set',
                       conn_type='spark',
                       host='yarn',
                       extra='{"spark-home": "/path/to/spark_home", ' +
                       '"spark-binary": "custom-spark-submit"}'))
        db.merge_conn(
            Connection(
                conn_id='spark_standalone_cluster',
                conn_type='spark',
                host='spark://spark-standalone-master:6066',
                extra='{"spark-home": "/path/to/spark_home", "deploy-mode": "cluster"}'
            ))
        db.merge_conn(
            Connection(
                conn_id='spark_standalone_cluster_client_mode',
                conn_type='spark',
                host='spark://spark-standalone-master:6066',
                extra='{"spark-home": "/path/to/spark_home", "deploy-mode": "client"}'
            ))
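These connection IDs are what `SparkSubmitHook` resolves at run time; a minimal usage sketch, assuming the contrib import path of the Airflow era these tests target and a hypothetical application path.

from airflow.contrib.hooks.spark_submit_hook import SparkSubmitHook

# Master URL, deploy mode and namespace are derived from the host/extra of
# the 'spark_k8s_cluster' connection merged above.
hook = SparkSubmitHook(conn_id='spark_k8s_cluster', name='example-job')
hook.submit(application='/opt/jobs/pi.py')  # hypothetical application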
Example #14
 def setUp(self):
     configuration.load_test_config()
     db.merge_conn(
         Connection(conn_id='wasb_test_key',
                    conn_type='wasb',
                    login='******',
                    password='******'))
     db.merge_conn(
         Connection(conn_id='wasb_test_sas_token',
                    conn_type='wasb',
                    login='******',
                    extra=json.dumps({'sas_token': 'token'})))
Example #15
    def create_connection(self, session=None):
        """
        Create a connection in the Connection table, according to whether it uses
        a proxy, TCP, UNIX sockets, or SSL. The connection ID is randomly generated.

        :param session: Session of the SQL Alchemy ORM (automatically generated with
                        decorator).
        """
        connection = Connection(conn_id=self.db_conn_id)
        uri = self._generate_connection_uri()
        self.log.info("Creating connection {}".format(self.db_conn_id))
        connection.parse_from_uri(uri)
        session.add(connection)
        session.commit()
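The `session=None` default works because the method is wrapped in Airflow's `provide_session` decorator; a minimal sketch of that pattern, assuming the `airflow.utils.db` import path of this era.

from airflow.utils.db import provide_session

@provide_session
def create_connection(self, session=None):
    # When the caller omits `session`, provide_session opens one, passes it
    # in as the keyword argument, and commits and closes it after the call.
    ...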
Example #17
 def test_api_key_required(self, mock_get_connection, mock_initialize):
     mock_get_connection.return_value = Connection()
     with self.assertRaises(AirflowException) as ctx:
         DatadogHook()
     self.assertEqual(
         str(ctx.exception),
         'api_key must be specified in the Datadog connection details')
Example #18
def copy_brazil_data_file(origin_host, origin_filepath, dest_bucket, dest_key):
    """Copy the Brazil data file to an S3 bucket.

    Copy the source file, which contains detailed data about Brazil, to an
    AWS S3 bucket to make it available to AWS EMR.

    Args:
        origin_host (str): host where the source file lives
        origin_filepath (str): full path to the file on the host
        dest_bucket (str): name of the bucket in which to store the file
        dest_key (str): prefix/name of the file in the destination bucket
    """
    logging.info('Copying Brazil data file '
                 f'FROM: http://{origin_host}/{origin_filepath} '
                 f'TO: s3://{dest_bucket}/{dest_key}')

    # Create a connection to the source server
    conn = Connection(conn_id='http_conn_brasilio',
                      conn_type='http',
                      host=origin_host,
                      port=80)
    session = settings.Session()  # get the session
    session.add(conn)
    session.commit()

    # Get the data file
    http_hook = HttpHook(method='GET', http_conn_id='http_conn_brasilio')
    response_br_data = http_hook.run(origin_filepath)

    # Store data file into s3 bucket
    s3_hook = S3Hook(aws_conn_id='aws_default')
    s3_hook.load_bytes(response_br_data.content,
                       dest_key,
                       bucket_name=dest_bucket,
                       replace=True)

    logging.info('Data copy finished.')
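A usage sketch: wiring the function into a DAG task with `PythonOperator` (old-style import path); all argument values are illustrative and `dag` is assumed to be defined elsewhere.

from airflow.operators.python_operator import PythonOperator

copy_task = PythonOperator(
    task_id='copy_brazil_data',
    python_callable=copy_brazil_data_file,
    op_kwargs={
        'origin_host': 'data.example.org',           # hypothetical host
        'origin_filepath': 'dataset/brazil.csv.gz',  # hypothetical path
        'dest_bucket': 'my-emr-staging',
        'dest_key': 'brazil/brazil.csv.gz',
    },
    dag=dag,
)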
Example #19
 def setUp(self):
     configuration.load_test_config()
     db.merge_conn(
         Connection(conn_id='imap_test',
                    host='base_url',
                    login='******',
                    password='******'))
Example #20
 def __init__(self, conn_id, variation: str):
     self.conn_id = conn_id
     self.var_name = "AIRFLOW_CONN_" + self.conn_id.upper()
     self.host = "host_{}.com".format(variation)
     self.conn_uri = ("mysql://*****:*****@" + self.host +
                      "/schema?extra1=val%2B1&extra2=val%2B2")
     self.conn = Connection(conn_id=self.conn_id, uri=self.conn_uri)
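A hedged usage sketch: exporting the helper's URI under its `var_name` makes the connection resolvable by hooks without touching the metadata database; `case` stands in for an instance of the class above.

import os
from airflow.hooks.base_hook import BaseHook

os.environ[case.var_name] = case.conn_uri
conn = BaseHook.get_connection(case.conn_id)
assert conn.host == case.host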
Example #21
 def test_host_encoded_https_connection(self, mock_get_connection):
     c = Connection(conn_id='http_default', conn_type='http',
                    host='https://localhost')
     mock_get_connection.return_value = c
     hook = HttpHook()
     hook.get_conn({})
     self.assertEqual(hook.base_url, 'https://localhost')
Example #22
 def test_deploy_function_function_not_exist(self, mock_get_connection, m):
     m.post("http://open-faas.io" + self.DEPLOY_FUNCTION,
            json={},
            status_code=202)
     mock_connection = Connection(host="http://open-faas.io")
     mock_get_connection.return_value = mock_connection
     self.assertEqual(self.hook.deploy_function(False, {}), None)
Example #23
 def setUp(self):
     configuration.load_test_config()
     db.merge_conn(
         Connection(
             conn_id='spark-default', conn_type='spark',
             host='yarn://yarn-master',
             extra='{"queue": "root.etl", "deploy-mode": "cluster"}')
     )
     db.merge_conn(
         Connection(
             conn_id='jdbc-default', conn_type='postgres',
             host='localhost', schema='default', port=5432,
             login='******', password='******',
             extra='{"conn_prefix":"jdbc:postgresql://"}'
         )
     )
Example #24
def airflow_conns(database):
    """Create Airflow connections for testing.

    We create them by adding Connection rows to Airflow's metadata database. Only
    postgres connections are set for now as our testing database is postgres.
    """
    uris = (
        f"postgres://{database.user}:{database.password}@{database.host}:{database.port}/public?dbname={database.dbname}",
        f"postgres://{database.user}:{database.password}@{database.host}:{database.port}/public",
    )
    ids = (
        "dbt_test_postgres_1",
        database.dbname,
    )
    session = settings.Session()

    connections = []
    for conn_id, uri in zip(ids, uris):
        existing = session.query(Connection).filter_by(conn_id=conn_id).first()
        if existing is not None:
            # Connections may exist from previous test run.
            session.delete(existing)
            session.commit()
        connections.append(Connection(conn_id=conn_id, uri=uri))

    session.add_all(connections)

    session.commit()

    yield ids

    session.close()
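A sketch of a test consuming the fixture above; it simply verifies that the yielded connection IDs resolve against the metadata database (`settings` and `Connection` are the same imports the fixture uses).

def test_connections_resolvable(airflow_conns):
    session = settings.Session()
    try:
        for conn_id in airflow_conns:
            found = session.query(Connection).filter_by(conn_id=conn_id).first()
            assert found is not None
    finally:
        session.close()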
Example #25
 def setUp(self, mock_get_connection, mock_initialize):
     mock_get_connection.return_value = Connection(
         extra=json.dumps({
             'app_key': APP_KEY,
             'api_key': API_KEY,
         }))
     self.hook = DatadogHook()
Example #26
def get_airflow_connection(conn_id=None):
    return Connection(
        conn_id='http_default',
        conn_type='http',
        host='test:8080/',
        extra='{"bareer": "test"}'
    )
Example #27
def get_airflow_connection_with_port(conn_id=None):
    return Connection(
        conn_id='http_default',
        conn_type='http',
        host='test.com',
        port=1234
    )
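Helpers like the two above are typically handed to `mock.patch` as replacements for a hook's connection lookup; a minimal sketch, assuming the old `airflow.hooks.http_hook` import path.

from unittest import mock
from airflow.hooks.http_hook import HttpHook

with mock.patch.object(HttpHook, 'get_connection',
                       side_effect=get_airflow_connection_with_port):
    hook = HttpHook(method='GET')
    hook.get_conn({})
    # base_url is assembled from host and port, e.g. 'http://test.com:1234'.
    print(hook.base_url)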
Example #28
 def setUp(self):
     configuration.load_test_config()
     db.merge_conn(
         Connection(
             conn_id='slack-webhook-default',
             extra='{"webhook_token": "your_token_here"}')
     )
Example #29
    def setUp(self):
        configuration.load_test_config()
        db.merge_conn(
            Connection(conn_id='spark_default',
                       conn_type='spark',
                       host='yarn://yarn-master'))
Example #30
 def setUp(self):
     configuration.load_test_config()
     db.merge_conn(
         Connection(
             conn_id='jira_default', conn_type='jira',
             host='https://localhost/jira/', port=443,
             extra='{"verify": "False", "project": "AIRFLOW"}'))
Example #31
 def setUp(self):
     configuration.load_test_config()
     db.merge_conn(
         Connection(conn_id=self.conn_id,
                    conn_type='http',
                    host='https://oapi.dingtalk.com',
                    password='******'))
Example #32
    def write_to_hdfs(rows: List[Tuple[str, str]]):
        conn: Connection = Connection.get_connection_from_secrets('local_hdfs')
        uri = conn.get_uri()
        pat = re.compile(r"http://(\w+(:\w+)?)?@")  # strip embedded credentials
        print(conn.get_uri())

        uri = pat.sub("http://", uri)
        print(uri)
        print(conn.login)
        client = InsecureClient(uri, user=conn.login)
        sch = avro.schema.make_avsc_object({
            'type':'record',
            'name':'Video',
            'fields': [
                {'type': {'type': 'string', 'avro.java.string': 'String'}, 'name': 'title'},
                {'type': ["null", {'type': 'string', 'avro.java.string': 'String'}], 'name': 'description'},
            ]
        })
        local_file_name = 'videos.avro'
        writer = DataFileWriter(open(local_file_name, "wb"), DatumWriter(), sch)
        for row in rows:
            print(row)
            writer.append({"title":row[0], "description":row[1]})
        writer.close()
        client.upload('/tmp/videos.avro', local_file_name)
Example #33
    def deserialize_connection(self, conn_id: str, value: str) -> 'Connection':
        """
        Given a serialized representation of the Airflow Connection, return an instance.
        Looks at the first character to determine how to deserialize.

        :param conn_id: connection id
        :param value: the serialized representation of the Connection object
        :return: the deserialized Connection
        """
        from airflow.models.connection import Connection

        value = value.strip()
        if value[0] == '{':
            return Connection.from_json(conn_id=conn_id, value=value)
        else:
            return Connection(conn_id=conn_id, uri=value)
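A sketch exercising both branches, assuming `backend` is an instance of the secrets-backend class this method belongs to; IDs and values are illustrative.

# URI form: the value does not start with '{', so it is parsed as a URI.
conn1 = backend.deserialize_connection(
    "my_db", "postgres://user:pass@localhost:5432/mydb")

# JSON form: a leading '{' routes the value to Connection.from_json.
conn2 = backend.deserialize_connection(
    "my_db", '{"conn_type": "postgres", "host": "localhost", "port": 5432}')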
Example #34
 def setUp(self):
     configuration.load_test_config()
     db.merge_conn(
         Connection(
             conn_id='jdbc_default', conn_type='jdbc',
             host='jdbc://localhost/', port=443,
             extra=json.dumps({"extra__jdbc__drv_path": "/path1/test.jar,/path2/t.jar2",
                               "extra__jdbc__drv_clsname": "com.driver.main"})))
Example #36
 def test_create_operator_with_correct_parameters_postgres(self, get_connections):
     connection = Connection()
     connection.parse_from_uri(
         "gcpcloudsql://*****:*****@8.8.8.8:3200/testdb?database_type=postgres&"
         "project_id=example-project&location=europe-west1&instance=testdb&"
         "use_proxy=False&use_ssl=False")
     get_connections.return_value = [connection]
     operator = CloudSqlQueryOperator(
         sql=['SELECT * FROM TABLE'],
         task_id='task_id'
     )
     operator.cloudsql_db_hook.create_connection()
     try:
         db_hook = operator.cloudsql_db_hook.get_database_hook()
         conn = db_hook._get_connections_from_db(db_hook.postgres_conn_id)[0]
     finally:
         operator.cloudsql_db_hook.delete_connection()
     self.assertEqual('postgres', conn.conn_type)
     self.assertEqual('8.8.8.8', conn.host)
     self.assertEqual(3200, conn.port)
     self.assertEqual('testdb', conn.schema)
Example #38
    def test_connection_extra_with_encryption_rotate_fernet_key(self, mock_get):
        """
        Tests rotating encrypted extras.
        """
        key1 = Fernet.generate_key()
        key2 = Fernet.generate_key()

        mock_get.return_value = key1.decode()
        test_connection = Connection(extra='testextra')
        self.assertTrue(test_connection.is_extra_encrypted)
        self.assertEqual(test_connection.extra, 'testextra')
        self.assertEqual(Fernet(key1).decrypt(test_connection._extra.encode()), b'testextra')

        # Test decrypt of old value with new key
        mock_get.return_value = ','.join([key2.decode(), key1.decode()])
        models._fernet = None
        self.assertEqual(test_connection.extra, 'testextra')

        # Test decrypt of new value with new key
        test_connection.rotate_fernet_key()
        self.assertTrue(test_connection.is_extra_encrypted)
        self.assertEqual(test_connection.extra, 'testextra')
        self.assertEqual(Fernet(key2).decrypt(test_connection._extra.encode()), b'testextra')
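The rotation being tested mirrors the multi-key pattern of the `cryptography` package itself; a standalone sketch using `MultiFernet` directly.

from cryptography.fernet import Fernet, MultiFernet

key1, key2 = Fernet.generate_key(), Fernet.generate_key()
token = Fernet(key1).encrypt(b"testextra")

# List the new key first and keep the old one so existing tokens still
# decrypt; rotate() then re-encrypts the token under the new key.
both = MultiFernet([Fernet(key2), Fernet(key1)])
assert both.decrypt(token) == b"testextra"
rotated = both.rotate(token)
assert Fernet(key2).decrypt(rotated) == b"testextra"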
Example #39
 def _setup_connections(get_connections, uri):
     gcp_connection = mock.MagicMock()
     gcp_connection.extra_dejson = mock.MagicMock()
     gcp_connection.extra_dejson.get.return_value = 'empty_project'
     cloudsql_connection = Connection()
     cloudsql_connection.parse_from_uri(uri)
     cloudsql_connection2 = Connection()
     cloudsql_connection2.parse_from_uri(uri)
     get_connections.side_effect = [[gcp_connection], [cloudsql_connection],
                                    [cloudsql_connection2]]