def setup_db(conn_id):
    """Initialise the data-access layer for the ``crawler`` schema of a MySQL connection.

    Builds a connection URI from the Airflow connection ``conn_id`` (with the
    schema overridden to ``crawler``) and hands it to ``DataAccessLayer.db_init``.

    :param conn_id: Airflow connection id, passed to ``MySqlHook`` as ``mysql_conn_id``.
    :return: whatever ``DataAccessLayer().db_init(uri)`` returns.
    """
    logger.info('mysql connection id : %s', conn_id)
    # get credentials
    hook = MySqlHook(mysql_conn_id=conn_id, schema='crawler')
    uri = hook.get_uri()
    # The URI embeds the login and plaintext password: log a masked copy only,
    # but hand the real URI to the data-access layer.
    logger.info('uri %s', _mask_uri_password(uri))
    dal = DataAccessLayer()
    return dal.db_init(uri)


def _mask_uri_password(uri):
    """Return *uri* with its password component (if any) replaced by ``***``; for logging only."""
    from urllib.parse import urlsplit, urlunsplit

    try:
        parts = urlsplit(uri)
    except ValueError:
        # Log formatting must never break DB setup; hide the whole URI instead.
        return '<unparseable uri>'
    if parts.password is None:
        return uri
    # The host part follows the last '@', so a raw '@' inside the password is handled too.
    host_part = parts.netloc.rpartition('@')[2]
    return urlunsplit(parts._replace(netloc=f'{parts.username}:***@{host_part}'))
class TestMySqlHookConn(unittest.TestCase):
    """Tests for how ``MySqlHook`` turns an Airflow ``Connection`` into connect arguments.

    ``get_connection`` is replaced by a mock returning ``self.connection``, and
    each test patches ``MySQLdb.connect`` so no real database is contacted; the
    tests then inspect the arguments the hook passed to ``connect``.
    """

    def setUp(self):
        super().setUp()
        # Dummy credentials; the assertions below expect these exact values to be
        # forwarded as ``user``/``passwd`` and embedded in the URI.
        self.connection = Connection(
            conn_type='mysql',
            login='login',
            password='password',
            host='host',
            schema='schema',
        )
        self.db_hook = MySqlHook()
        self.db_hook.get_connection = mock.Mock()
        self.db_hook.get_connection.return_value = self.connection

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_conn(self, mock_connect):
        self.db_hook.get_conn()
        mock_connect.assert_called_once()
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['user'], 'login')
        self.assertEqual(kwargs['passwd'], 'password')
        self.assertEqual(kwargs['host'], 'host')
        self.assertEqual(kwargs['db'], 'schema')

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_uri(self, mock_connect):
        self.connection.extra = json.dumps({'charset': 'utf-8'})
        self.db_hook.get_conn()
        mock_connect.assert_called_once()
        self.assertEqual(
            self.db_hook.get_uri(),
            "mysql://login:password@host/schema?charset=utf-8",
        )

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_conn_port(self, mock_connect):
        self.connection.port = 3307
        self.db_hook.get_conn()
        mock_connect.assert_called_once()
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['port'], 3307)

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_conn_charset(self, mock_connect):
        self.connection.extra = json.dumps({'charset': 'utf-8'})
        self.db_hook.get_conn()
        mock_connect.assert_called_once()
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['charset'], 'utf-8')
        self.assertEqual(kwargs['use_unicode'], True)

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_conn_cursor(self, mock_connect):
        self.connection.extra = json.dumps({'cursor': 'sscursor'})
        self.db_hook.get_conn()
        mock_connect.assert_called_once()
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['cursorclass'], MySQLdb.cursors.SSCursor)

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_conn_local_infile(self, mock_connect):
        self.connection.extra = json.dumps({'local_infile': True})
        self.db_hook.get_conn()
        mock_connect.assert_called_once()
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['local_infile'], 1)

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_con_unix_socket(self, mock_connect):
        self.connection.extra = json.dumps({'unix_socket': "/tmp/socket"})
        self.db_hook.get_conn()
        mock_connect.assert_called_once()
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['unix_socket'], '/tmp/socket')

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_conn_ssl_as_dictionary(self, mock_connect):
        self.connection.extra = json.dumps({'ssl': SSL_DICT})
        self.db_hook.get_conn()
        mock_connect.assert_called_once()
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['ssl'], SSL_DICT)

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_conn_ssl_as_string(self, mock_connect):
        # 'ssl' given as a JSON-encoded string must reach connect() as the decoded dict.
        self.connection.extra = json.dumps({'ssl': json.dumps(SSL_DICT)})
        self.db_hook.get_conn()
        mock_connect.assert_called_once()
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['ssl'], SSL_DICT)

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    @mock.patch('airflow.contrib.hooks.aws_hook.AwsHook.get_client_type')
    def test_get_conn_rds_iam(self, mock_client, mock_connect):
        # With {"iam": true} the token from the (mocked) AWS client replaces the
        # stored password, and the cleartext auth plugin group is requested.
        self.connection.extra = '{"iam":true}'
        mock_client.return_value.generate_db_auth_token.return_value = 'aws_token'
        self.db_hook.get_conn()
        mock_connect.assert_called_once_with(
            user='login',
            passwd='aws_token',
            host='host',
            db='schema',
            port=3306,
            read_default_group='enable-cleartext-plugin',
        )
def create_tables_mysql(*args, **kwargs):
    """Print the connection URI of the ``mysql_zomato`` connection, with its password masked.

    Despite its name, this function currently creates no tables; it only
    resolves and prints the URI. ``*args``/``**kwargs`` are accepted and ignored
    (presumably so it can serve as an Airflow ``python_callable`` — confirm).
    """
    from urllib.parse import urlsplit, urlunsplit

    mysql_hook = MySqlHook(mysql_conn_id="mysql_zomato")
    uri = mysql_hook.get_uri()
    # get_uri() embeds the plaintext password; mask it before it reaches stdout/task logs.
    try:
        parts = urlsplit(uri)
    except ValueError:
        print('<unparseable uri>')
        return
    if parts.password is not None:
        # The host part follows the last '@', so a raw '@' inside the password is handled too.
        host_part = parts.netloc.rpartition('@')[2]
        uri = urlunsplit(parts._replace(netloc=f"{parts.username}:***@{host_part}"))
    print(uri)