예제 #1
0
def setup_db(conn_id):
    """Initialise the data access layer from a MySQL hook connection.

    Builds a ``MySqlHook`` for ``conn_id`` against the ``crawler`` schema,
    asks it for its connection URI and hands that URI to
    ``DataAccessLayer.db_init``.

    Args:
        conn_id: Connection id passed to ``MySqlHook`` as ``mysql_conn_id``.

    Returns:
        Whatever ``DataAccessLayer.db_init`` returns for the URI.
    """
    logger.info(f'mysql connection id : {conn_id}')

    # get credentials
    hook = MySqlHook(mysql_conn_id=conn_id, schema='crawler')
    uri = hook.get_uri()

    # The URI may embed ``login:password@``; that must never reach the log.
    # Everything up to the last '@' is treated as user info and masked, which
    # errs on the side of hiding too much rather than too little.
    scheme, _, remainder = uri.partition('://')
    _, has_userinfo, location = remainder.rpartition('@')
    safe_uri = f'{scheme}://***@{location}' if has_userinfo else uri
    logger.info(f'uri {safe_uri}')

    # The unmasked URI is still what the data access layer receives.
    dal = DataAccessLayer()
    return dal.db_init(uri)
예제 #2
0
class TestMySqlHookConn(unittest.TestCase):
    """Check how ``MySqlHook`` turns a ``Connection`` into connect() kwargs.

    Every test replaces ``MySQLdb.connect`` with a mock, tweaks the shared
    ``Connection`` fixture (port or ``extra`` JSON), calls ``get_conn()`` and
    inspects the keyword arguments the mock received.
    """

    def setUp(self):
        """Create a hook whose ``get_connection`` returns a fixed fixture."""
        super(TestMySqlHookConn, self).setUp()

        # The literal values below are what the assertions in the tests
        # compare against, so fixture and expectations must stay in sync.
        self.connection = Connection(
            conn_type='mysql',
            login='login',
            password='password',
            host='host',
            schema='schema',
        )

        self.db_hook = MySqlHook()
        # Short-circuit the connection lookup so the fixture is always used.
        self.db_hook.get_connection = mock.Mock()
        self.db_hook.get_connection.return_value = self.connection

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_conn(self, mock_connect):
        """Login, password, host and schema are passed as connect() kwargs."""
        self.db_hook.get_conn()
        self.assertEqual(mock_connect.call_count, 1)
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['user'], 'login')
        self.assertEqual(kwargs['passwd'], 'password')
        self.assertEqual(kwargs['host'], 'host')
        self.assertEqual(kwargs['db'], 'schema')

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_uri(self, mock_connect):
        """``get_uri()`` renders the fixture, including the charset extra."""
        self.connection.extra = json.dumps({'charset': 'utf-8'})
        self.db_hook.get_conn()
        self.assertEqual(mock_connect.call_count, 1)
        self.assertEqual(self.db_hook.get_uri(),
                         "mysql://login:password@host/schema?charset=utf-8")

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_conn_port(self, mock_connect):
        """A port set on the connection is forwarded as ``port``."""
        self.connection.port = 3307
        self.db_hook.get_conn()
        self.assertEqual(mock_connect.call_count, 1)
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['port'], 3307)

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_conn_charset(self, mock_connect):
        """A utf-8 charset extra sets ``charset`` and ``use_unicode``."""
        self.connection.extra = json.dumps({'charset': 'utf-8'})
        self.db_hook.get_conn()
        self.assertEqual(mock_connect.call_count, 1)
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['charset'], 'utf-8')
        self.assertEqual(kwargs['use_unicode'], True)

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_conn_cursor(self, mock_connect):
        """The ``sscursor`` extra selects ``MySQLdb.cursors.SSCursor``."""
        self.connection.extra = json.dumps({'cursor': 'sscursor'})
        self.db_hook.get_conn()
        self.assertEqual(mock_connect.call_count, 1)
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['cursorclass'], MySQLdb.cursors.SSCursor)

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_conn_local_infile(self, mock_connect):
        """A true ``local_infile`` extra is forwarded as ``1``."""
        self.connection.extra = json.dumps({'local_infile': True})
        self.db_hook.get_conn()
        self.assertEqual(mock_connect.call_count, 1)
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['local_infile'], 1)

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_con_unix_socket(self, mock_connect):
        """The ``unix_socket`` extra is forwarded unchanged."""
        self.connection.extra = json.dumps({'unix_socket': "/tmp/socket"})
        self.db_hook.get_conn()
        self.assertEqual(mock_connect.call_count, 1)
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['unix_socket'], '/tmp/socket')

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_conn_ssl_as_dictionary(self, mock_connect):
        """An ``ssl`` extra given as a JSON object is forwarded as a dict."""
        self.connection.extra = json.dumps({'ssl': SSL_DICT})
        self.db_hook.get_conn()
        self.assertEqual(mock_connect.call_count, 1)
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['ssl'], SSL_DICT)

    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    def test_get_conn_ssl_as_string(self, mock_connect):
        """An ``ssl`` extra given as a JSON-encoded string is decoded."""
        self.connection.extra = json.dumps({'ssl': json.dumps(SSL_DICT)})
        self.db_hook.get_conn()
        self.assertEqual(mock_connect.call_count, 1)
        args, kwargs = mock_connect.call_args
        self.assertEqual(args, ())
        self.assertEqual(kwargs['ssl'], SSL_DICT)

    # Decorators apply bottom-up, so the AwsHook mock is the first argument.
    @mock.patch('airflow.hooks.mysql_hook.MySQLdb.connect')
    @mock.patch('airflow.contrib.hooks.aws_hook.AwsHook.get_client_type')
    def test_get_conn_rds_iam(self, mock_client, mock_connect):
        """With the ``iam`` extra the generated token is used as password."""
        self.connection.extra = '{"iam":true}'
        mock_client.return_value.generate_db_auth_token.return_value = 'aws_token'
        self.db_hook.get_conn()
        mock_connect.assert_called_once_with(
            user='login',
            passwd='aws_token',
            host='host',
            db='schema',
            port=3306,
            read_default_group='enable-cleartext-plugin')
예제 #3
0
def create_tables_mysql(*args, **kwargs):
    """Print the connection URI of the ``mysql_zomato`` connection.

    NOTE(review): despite its name, this function as written creates no
    tables; it only builds a ``MySqlHook`` and prints the URI it reports.

    Args:
        *args: Accepted and ignored.
        **kwargs: Accepted and ignored.
    """
    mysql_hook = MySqlHook(mysql_conn_id="mysql_zomato")
    uri = mysql_hook.get_uri()

    # The URI may embed ``login:password@``; mask everything up to the last
    # '@' so credentials do not end up in stdout or captured task logs.
    scheme, _, remainder = uri.partition('://')
    _, has_userinfo, location = remainder.rpartition('@')
    print(f'{scheme}://***@{location}' if has_userinfo else uri)