def test_create_operator_with_wrong_parameters(self,
                                               project_id,
                                               location,
                                               instance_name,
                                               database_type,
                                               use_proxy,
                                               use_ssl,
                                               sql,
                                               message,
                                               get_connections):
    """An invalid parameter combination must make the query operator fail.

    A ``gcpcloudsql://`` URI is assembled from the database type, project,
    location, instance name and the proxy/SSL flags. The URI is handed to
    ``self._setup_connections`` together with ``get_connections``. The
    operator is then built and executed with ``sql``. The test expects an
    ``AirflowException`` whose text contains ``message``.

    NOTE(review): the argument values (and ``get_connections``, presumably a
    mock) come from decorators outside this view.
    """
    uri_template = (
        "gcpcloudsql://*****:*****@127.0.0.1:3200/testdb?"
        "database_type={database_type}&"
        "project_id={project_id}&location={location}&instance={instance_name}&"
        "use_proxy={use_proxy}&use_ssl={use_ssl}"
    )
    uri = uri_template.format(
        database_type=database_type,
        project_id=project_id,
        location=location,
        instance_name=instance_name,
        use_proxy=use_proxy,
        use_ssl=use_ssl,
    )
    self._setup_connections(get_connections, uri)

    # Construction and execution both sit inside the context manager, so a
    # failure raised by either step satisfies the assertion.
    with self.assertRaises(AirflowException) as raised:
        CloudSqlQueryOperator(sql=sql, task_id='task_id').execute(None)

    self.assertIn(message, str(raised.exception))
def test_cloudsql_hook_delete_connection_on_exception(
        self, get_connections, run, get_connection, delete_connection):
    """A failing query still triggers ``delete_connection``.

    ``run`` is made to raise. The test checks two things: the exception with
    the same message escapes ``operator.execute``, and ``delete_connection``
    was called exactly once with no arguments.

    NOTE(review): ``get_connections``, ``run``, ``get_connection`` and
    ``delete_connection`` appear to be mocks injected by decorators outside
    this view. TODO confirm.
    """
    cloudsql_connection = Connection()
    cloudsql_connection.parse_from_uri(
        "gcpcloudsql://*****:*****@127.0.0.1:3200/testdb?database_type=mysql&"
        "project_id=example-project&location=europe-west1&instance=testdb&"
        "use_proxy=False")
    get_connection.return_value = cloudsql_connection

    db_extra = {
        "project_id": "example-project",
        "location": "europe-west1",
        "instance": "testdb",
        "database_type": "mysql",
    }
    db_connection = Connection()
    db_connection.host = "127.0.0.1"
    db_connection.set_extra(json.dumps(db_extra))
    get_connections.return_value = [db_connection]

    # Force the query execution step to blow up.
    run.side_effect = Exception("Exception when running a query")

    operator = CloudSqlQueryOperator(sql=['SELECT * FROM TABLE'], task_id='task_id')
    with self.assertRaises(Exception) as raised:
        operator.execute(None)

    self.assertEqual("Exception when running a query", str(raised.exception))
    delete_connection.assert_called_once_with()
def test_create_operator_with_too_long_unix_socket_path(
        self, get_connections):
    """An overly long instance name is rejected when the proxy uses a UNIX socket.

    The URI enables the proxy (``use_proxy=True``) with
    ``sql_proxy_use_tcp=False``. Executing the operator must raise
    ``AirflowException`` mentioning the UNIX socket path length limit.
    """
    long_instance_name = ("test_db_with_long_name_a_bit_above"
                          "_the_limit_of_UNIX_socket_asdadadasadasd")
    uri = (
        "gcpcloudsql://*****:*****@127.0.0.1:3200/testdb?database_type=postgres&"
        "project_id=example-project&location=europe-west1&"
        "instance=" + long_instance_name + "&"
        "use_proxy=True&sql_proxy_use_tcp=False"
    )
    self._setup_connections(get_connections, uri)

    operator = CloudSqlQueryOperator(sql=['SELECT * FROM TABLE'], task_id='task_id')
    with self.assertRaises(AirflowException) as raised:
        operator.execute(None)

    self.assertIn("The UNIX socket path length cannot exceed", str(raised.exception))
def test_create_operator_with_correct_parameters_mysql_ssl(
        self, get_connections):
    """A direct (no proxy) MySQL connection with SSL carries the URI settings.

    After ``create_connection``, the MySQL connection fetched through the
    database hook has these properties:

    * type ``mysql``, host ``127.0.0.1``, port 3200 and schema ``testdb``;
    * ``/bin/bash`` under the ``cert``, ``key`` and ``ca`` entries of its
      JSON-encoded ``ssl`` extra.
    """
    uri = (
        "gcpcloudsql://*****:*****@127.0.0.1:3200/testdb?database_type=mysql&"
        "project_id=example-project&location=europe-west1&instance=testdb&"
        "use_proxy=False&use_ssl=True&sslcert=/bin/bash&"
        "sslkey=/bin/bash&sslrootcert=/bin/bash"
    )
    self._setup_connections(get_connections, uri)
    operator = CloudSqlQueryOperator(sql=['SELECT * FROM TABLE'], task_id='task_id')

    operator.cloudsql_db_hook.create_connection()
    try:
        db_hook = operator.cloudsql_db_hook.get_database_hook()
        conn = db_hook._get_connections_from_db(db_hook.mysql_conn_id)[0]  # pylint: disable=no-member
    finally:
        # Clean up the created connection even if the lookup above failed.
        operator.cloudsql_db_hook.delete_connection()

    self.assertEqual('mysql', conn.conn_type)
    self.assertEqual('127.0.0.1', conn.host)
    self.assertEqual(3200, conn.port)
    self.assertEqual('testdb', conn.schema)

    # The 'ssl' extra is itself a JSON string; decode it once.
    ssl_settings = json.loads(conn.extra_dejson['ssl'])
    for ssl_field in ('cert', 'key', 'ca'):
        self.assertEqual('/bin/bash', ssl_settings[ssl_field])
def test_create_operator_with_correct_parameters_mysql_tcp(
        self, get_connections):
    """A proxied MySQL connection over TCP is created with a rewritten port.

    The URI enables the proxy with ``sql_proxy_use_tcp=True``. The resulting
    MySQL connection has type ``mysql``, host ``127.0.0.1`` and schema
    ``testdb``, and a port that differs from the 3200 given in the URI.

    NOTE(review): the new port is presumably the local proxy's TCP port;
    confirm against the hook.
    """
    uri = (
        "gcpcloudsql://*****:*****@127.0.0.1:3200/testdb?database_type=mysql&"
        "project_id=example-project&location=europe-west1&instance=testdb&"
        "use_proxy=True&sql_proxy_use_tcp=True"
    )
    self._setup_connections(get_connections, uri)
    operator = CloudSqlQueryOperator(sql=['SELECT * FROM TABLE'], task_id='task_id')

    operator.cloudsql_db_hook.create_connection()
    try:
        db_hook = operator.cloudsql_db_hook.get_database_hook()
        conn = db_hook._get_connections_from_db(db_hook.mysql_conn_id)[0]  # pylint: disable=no-member
    finally:
        # Clean up the created connection even if the lookup above failed.
        operator.cloudsql_db_hook.delete_connection()

    self.assertEqual('mysql', conn.conn_type)
    self.assertEqual('127.0.0.1', conn.host)
    self.assertNotEqual(3200, conn.port)
    self.assertEqual('testdb', conn.schema)
def test_create_operator_with_not_too_long_unix_socket_path(
        self, get_connections):
    """An instance name within the UNIX socket limit is accepted.

    With the proxy on a UNIX socket (``sql_proxy_use_tcp=False``), creating
    the connection succeeds. The resulting Postgres connection has type
    ``postgres`` and schema ``testdb``.
    """
    instance_name = "test_db_with_longname_but_with_limit_of_UNIX_socket"
    uri = (
        "gcpcloudsql://*****:*****@127.0.0.1:3200/testdb?database_type=postgres&"
        "project_id=example-project&location=europe-west1&"
        "instance=" + instance_name + "&"
        "use_proxy=True&sql_proxy_use_tcp=False"
    )
    self._setup_connections(get_connections, uri)
    operator = CloudSqlQueryOperator(sql=['SELECT * FROM TABLE'], task_id='task_id')

    operator.cloudsql_db_hook.create_connection()
    try:
        db_hook = operator.cloudsql_db_hook.get_database_hook()
        conn = db_hook._get_connections_from_db(
            db_hook.postgres_conn_id)[0]  # pylint: disable=no-member
    finally:
        # Clean up the created connection even if the lookup above failed.
        operator.cloudsql_db_hook.delete_connection()

    self.assertEqual('postgres', conn.conn_type)
    self.assertEqual('testdb', conn.schema)
# [END howto_operator_cloudsql_query_connections]

# [START howto_operator_cloudsql_query_operators]

connection_names = [
    "proxy_postgres_tcp",
    "proxy_postgres_socket",
    "public_postgres_tcp",
    "public_postgres_tcp_ssl",
    "proxy_mysql_tcp",
    "proxy_mysql_socket",
    "public_mysql_tcp",
    "public_mysql_tcp_ssl",
    "public_mysql_tcp_ssl_no_project_id",
]

tasks = []

with models.DAG(
    dag_id='example_gcp_sql_query',
    default_args=default_args,
    schedule_interval=None,
) as dag:
    # One task per connection, all running the same SQL.
    for connection_name in connection_names:
        tasks.append(
            CloudSqlQueryOperator(
                gcp_cloudsql_conn_id=connection_name,
                task_id="example_gcp_sql_task_" + connection_name,
                sql=SQL,
            )
        )
    # Chain the tasks so they run one after another in list order.
    for upstream_task, downstream_task in zip(tasks, tasks[1:]):
        upstream_task >> downstream_task

# [END howto_operator_cloudsql_query_operators]