示例#1
0
文件: tests.py 项目: n040661/Archery
 def test_query_masking(self):
     """PgSQLEngine.query_masking should return a result equal to the ResultSet it was given."""
     engine = PgSQLEngine(instance=self.ins)
     given = ResultSet()
     returned = engine.query_masking(db_name=0, sql='', resultset=given)
     self.assertEqual(returned, given)
示例#2
0
 def test_query_masking_not_select(self, _data_masking):
     """A non-select statement (here an EXPLAIN) must come back equal to the input ResultSet."""
     engine = MysqlEngine(instance=self.ins1)
     given = ResultSet()
     explain_sql = 'explain select 1'
     returned = engine.query_masking(db_name='archery', sql=explain_sql, resultset=given)
     self.assertEqual(returned, given)
示例#3
0
文件: tests.py 项目: n040661/Archery
class TestPgSQL(TestCase):
    """Unit tests for PgSQLEngine; driver connections and engine queries are mocked."""

    @classmethod
    def setUpClass(cls):
        # The framework's own class-level setup must run before ours.
        super().setUpClass()
        cls.ins = Instance(instance_name='some_ins', type='slave', db_type='pgsql', host='some_host',
                           port=1366, user='******', password='******')
        cls.ins.save()

    @classmethod
    def tearDownClass(cls):
        cls.ins.delete()
        # Let the framework finish its class-level teardown last.
        super().tearDownClass()

    @patch('psycopg2.connect')
    def test_get_connection(self, _conn):
        new_engine = PgSQLEngine(instance=self.ins)
        new_engine.get_connection()
        _conn.assert_called_once()

    # Patches are applied bottom-up: psycopg2.connect is replaced by a MagicMock first,
    # so the dotted 'connect.cursor' targets resolve to attributes of that mock.
    @patch('psycopg2.connect.cursor.execute')
    @patch('psycopg2.connect.cursor')
    @patch('psycopg2.connect')
    def test_query(self, _conn, _cursor, _execute):
        _conn.return_value.cursor.return_value.fetchmany.return_value = [(1,)]
        new_engine = PgSQLEngine(instance=self.ins)
        query_result = new_engine.query(db_name=0, sql='select 1', limit_num=100)
        self.assertIsInstance(query_result, ResultSet)
        self.assertListEqual(query_result.rows, [(1,)])

    @patch('sql.engines.pgsql.PgSQLEngine.query',
           return_value=ResultSet(rows=[('postgres',), ('archery',), ('template1',), ('template0',)]))
    def test_get_all_databases(self, _query):
        new_engine = PgSQLEngine(instance=self.ins)
        dbs = new_engine.get_all_databases()
        self.assertListEqual(dbs, ['archery'])

    @patch('sql.engines.pgsql.PgSQLEngine.query',
           return_value=ResultSet(rows=[('information_schema',), ('archery',), ('pg_catalog',)]))
    def test_get_all_schemas(self, _query):
        new_engine = PgSQLEngine(instance=self.ins)
        schemas = new_engine.get_all_schemas(db_name='archery')
        self.assertListEqual(schemas, ['archery'])

    @patch('sql.engines.pgsql.PgSQLEngine.query', return_value=ResultSet(rows=[('test',), ('test2',)]))
    def test_get_all_tables(self, _query):
        new_engine = PgSQLEngine(instance=self.ins)
        tables = new_engine.get_all_tables(db_name='archery', schema_name='archery')
        self.assertListEqual(tables, ['test2'])

    @patch('sql.engines.pgsql.PgSQLEngine.query',
           return_value=ResultSet(rows=[('id',), ('name',)]))
    def test_get_all_columns_by_tb(self, _query):
        new_engine = PgSQLEngine(instance=self.ins)
        columns = new_engine.get_all_columns_by_tb(db_name='archery', tb_name='test2', schema_name='archery')
        self.assertListEqual(columns, ['id', 'name'])

    @patch('sql.engines.pgsql.PgSQLEngine.query',
           return_value=ResultSet(rows=[('postgres',), ('archery',), ('template1',), ('template0',)]))
    def test_describe_table(self, _query):
        new_engine = PgSQLEngine(instance=self.ins)
        describe = new_engine.describe_table(db_name='archery', schema_name='archery', tb_name='text')
        self.assertIsInstance(describe, ResultSet)

    def test_query_check_disable_sql(self):
        sql = "update xxx set a=1 "
        new_engine = PgSQLEngine(instance=self.ins)
        check_result = new_engine.query_check(db_name='archery', sql=sql)
        self.assertDictEqual(check_result,
                             {'msg': '不止的支持查询语法类型!', 'bad_query': True, 'filtered_sql': sql.strip(), 'has_star': False})

    def test_query_check_star_sql(self):
        sql = "select * from xx "
        new_engine = PgSQLEngine(instance=self.ins)
        check_result = new_engine.query_check(db_name='archery', sql=sql)
        self.assertDictEqual(check_result,
                             {'msg': 'SQL语句中含有 * ', 'bad_query': False, 'filtered_sql': sql.strip(), 'has_star': True})

    def test_filter_sql_with_delimiter(self):
        sql = "select * from xx;"
        new_engine = PgSQLEngine(instance=self.ins)
        check_result = new_engine.filter_sql(sql=sql, limit_num=100)
        self.assertEqual(check_result, "select * from xx limit 100;")

    def test_filter_sql_without_delimiter(self):
        sql = "select * from xx"
        new_engine = PgSQLEngine(instance=self.ins)
        check_result = new_engine.filter_sql(sql=sql, limit_num=100)
        self.assertEqual(check_result, "select * from xx limit 100;")

    def test_query_masking(self):
        query_result = ResultSet()
        new_engine = PgSQLEngine(instance=self.ins)
        masking_result = new_engine.query_masking(db_name=0, sql='', resultset=query_result)
        self.assertEqual(masking_result, query_result)
示例#4
0
class TestMysql(TestCase):
    """Unit tests for MysqlEngine; MySQLdb, InceptionEngine and engine queries are mocked."""

    def setUp(self):
        self.ins1 = Instance(instance_name='some_ins', type='slave', db_type='mysql', host='some_host',
                             port=1366, user='******', password='******')
        self.ins1.save()
        self.sys_config = SysConfig()
        self.wf = SqlWorkflow.objects.create(
            workflow_name='some_name',
            group_id=1,
            group_name='g1',
            engineer_display='',
            audit_auth_groups='some_group',
            create_time=datetime.now() - timedelta(days=1),
            status='workflow_finish',
            is_backup=True,
            instance=self.ins1,
            db_name='some_db',
            syntax_type=1
        )
        SqlWorkflowContent.objects.create(workflow=self.wf)

    def tearDown(self):
        self.ins1.delete()
        # Reset any config a test changed (e.g. critical_ddl_regex).
        self.sys_config.replace(json.dumps({}))
        SqlWorkflow.objects.all().delete()
        SqlWorkflowContent.objects.all().delete()

    @patch('MySQLdb.connect')
    def testGetConnection(self, connect):
        new_engine = MysqlEngine(instance=self.ins1)
        new_engine.get_connection()
        connect.assert_called_once()

    @patch('MySQLdb.connect')
    def testQuery(self, connect):
        cur = Mock()
        connect.return_value.cursor = cur
        cur.return_value.execute = Mock()
        cur.return_value.fetchmany.return_value = (('v1', 'v2'),)
        cur.return_value.description = (('k1', 'some_other_des'), ('k2', 'some_other_des'))
        new_engine = MysqlEngine(instance=self.ins1)
        query_result = new_engine.query(sql='some_str', limit_num=100)
        cur.return_value.execute.assert_called()
        cur.return_value.fetchmany.assert_called_once_with(size=100)
        connect.return_value.close.assert_called_once()
        self.assertIsInstance(query_result, ResultSet)

    @patch.object(MysqlEngine, 'query')
    def testAllDb(self, mock_query):
        db_result = ResultSet()
        db_result.rows = [('db_1',), ('db_2',)]
        mock_query.return_value = db_result
        new_engine = MysqlEngine(instance=self.ins1)
        dbs = new_engine.get_all_databases()
        self.assertEqual(dbs.rows, ['db_1', 'db_2'])

    @patch.object(MysqlEngine, 'query')
    def testAllTables(self, mock_query):
        table_result = ResultSet()
        table_result.rows = [('tb_1', 'some_des'), ('tb_2', 'some_des')]
        mock_query.return_value = table_result
        new_engine = MysqlEngine(instance=self.ins1)
        tables = new_engine.get_all_tables('some_db')
        mock_query.assert_called_once_with(db_name='some_db', sql=ANY)
        self.assertEqual(tables.rows, ['tb_1', 'tb_2'])

    @patch.object(MysqlEngine, 'query')
    def testAllColumns(self, mock_query):
        db_result = ResultSet()
        db_result.rows = [('col_1', 'type'), ('col_2', 'type2')]
        mock_query.return_value = db_result
        new_engine = MysqlEngine(instance=self.ins1)
        dbs = new_engine.get_all_columns_by_tb('some_db', 'some_tb')
        self.assertEqual(dbs.rows, ['col_1', 'col_2'])

    @patch.object(MysqlEngine, 'query')
    def testDescribe(self, mock_query):
        new_engine = MysqlEngine(instance=self.ins1)
        # Second argument is the table name (was mistakenly the db name).
        new_engine.describe_table('some_db', 'some_tb')
        mock_query.assert_called_once()

    def testQueryCheck(self):
        new_engine = MysqlEngine(instance=self.ins1)
        sql_without_limit = '-- 测试\n select user from usertable'
        check_result = new_engine.query_check(db_name='some_db', sql=sql_without_limit)
        self.assertEqual(check_result['filtered_sql'], 'select user from usertable')

    def test_query_check_wrong_sql(self):
        new_engine = MysqlEngine(instance=self.ins1)
        wrong_sql = '-- 测试'
        check_result = new_engine.query_check(db_name='some_db', sql=wrong_sql)
        self.assertDictEqual(check_result,
                             {'msg': '不支持的查询语法类型!', 'bad_query': True, 'filtered_sql': '-- 测试', 'has_star': False})

    def test_query_check_update_sql(self):
        new_engine = MysqlEngine(instance=self.ins1)
        update_sql = 'update user set id=0'
        check_result = new_engine.query_check(db_name='some_db', sql=update_sql)
        self.assertDictEqual(check_result,
                             {'msg': '不支持的查询语法类型!', 'bad_query': True, 'filtered_sql': 'update user set id=0',
                              'has_star': False})

    def test_filter_sql_with_delimiter(self):
        new_engine = MysqlEngine(instance=self.ins1)
        sql_without_limit = 'select user from usertable;'
        check_result = new_engine.filter_sql(sql=sql_without_limit, limit_num=100)
        self.assertEqual(check_result, 'select user from usertable limit 100;')

    def test_filter_sql_without_delimiter(self):
        new_engine = MysqlEngine(instance=self.ins1)
        sql_without_limit = 'select user from usertable'
        check_result = new_engine.filter_sql(sql=sql_without_limit, limit_num=100)
        self.assertEqual(check_result, 'select user from usertable limit 100;')

    def test_filter_sql_with_limit(self):
        new_engine = MysqlEngine(instance=self.ins1)
        sql_without_limit = 'select user from usertable limit 10'
        check_result = new_engine.filter_sql(sql=sql_without_limit, limit_num=1)
        self.assertEqual(check_result, 'select user from usertable limit 10;')

    @patch('sql.engines.mysql.data_masking', return_value=ResultSet())
    def test_query_masking(self, _data_masking):
        query_result = ResultSet()
        new_engine = MysqlEngine(instance=self.ins1)
        masking_result = new_engine.query_masking(db_name='archery', sql='select 1', resultset=query_result)
        # A select must actually go through data_masking, otherwise the patch proves nothing.
        _data_masking.assert_called_once()
        self.assertIsInstance(masking_result, ResultSet)

    @patch('sql.engines.mysql.data_masking', return_value=ResultSet())
    def test_query_masking_not_select(self, _data_masking):
        query_result = ResultSet()
        new_engine = MysqlEngine(instance=self.ins1)
        masking_result = new_engine.query_masking(db_name='archery', sql='explain select 1', resultset=query_result)
        # Non-select statements must bypass data_masking entirely.
        _data_masking.assert_not_called()
        self.assertEqual(masking_result, query_result)

    def test_execute_check_select_sql(self):
        sql = 'select * from user'
        row = ReviewResult(id=1, errlevel=2,
                           stagestatus='驳回高危SQL',
                           errormessage='仅支持DML和DDL语句,查询语句请使用SQL查询功能!',
                           sql=sql)
        new_engine = MysqlEngine(instance=self.ins1)
        check_result = new_engine.execute_check(db_name='archery', sql=sql)
        self.assertIsInstance(check_result, ReviewSet)
        self.assertEqual(check_result.rows[0].__dict__, row.__dict__)

    def test_execute_check_critical_sql(self):
        self.sys_config.set('critical_ddl_regex', '^|update')
        self.sys_config.get_all_config()
        sql = 'update user set id=1'
        row = ReviewResult(id=1, errlevel=2,
                           stagestatus='驳回高危SQL',
                           errormessage='禁止提交匹配' + '^|update' + '条件的语句!',
                           sql=sql)
        new_engine = MysqlEngine(instance=self.ins1)
        check_result = new_engine.execute_check(db_name='archery', sql=sql)
        self.assertIsInstance(check_result, ReviewSet)
        self.assertEqual(check_result.rows[0].__dict__, row.__dict__)

    @patch('sql.engines.mysql.InceptionEngine')
    def test_execute_check_normal_sql(self, _inception_engine):
        sql = 'update user set id=1'
        row = ReviewResult(id=1,
                           errlevel=0,
                           stagestatus='Audit completed',
                           errormessage='None',
                           sql=sql,
                           affected_rows=0,
                           execute_time=0, )
        _inception_engine.return_value.execute_check.return_value = ReviewSet(full_sql=sql, rows=[row])
        new_engine = MysqlEngine(instance=self.ins1)
        check_result = new_engine.execute_check(db_name='archery', sql=sql)
        self.assertIsInstance(check_result, ReviewSet)
        self.assertEqual(check_result.rows[0].__dict__, row.__dict__)

    @patch('sql.engines.mysql.InceptionEngine')
    def test_execute_check_normal_sql_with_Exception(self, _inception_engine):
        sql = 'update user set id=1'
        _inception_engine.return_value.execute_check.side_effect = RuntimeError()
        new_engine = MysqlEngine(instance=self.ins1)
        with self.assertRaises(RuntimeError):
            new_engine.execute_check(db_name=0, sql=sql)

    @patch('sql.engines.mysql.InceptionEngine')
    def test_execute_workflow(self, _inception_engine):
        sql = 'update user set id=1'
        _inception_engine.return_value.execute.return_value = ReviewSet(full_sql=sql)
        new_engine = MysqlEngine(instance=self.ins1)
        execute_result = new_engine.execute_workflow(self.wf)
        self.assertIsInstance(execute_result, ReviewSet)

    # Patches apply bottom-up: MySQLdb.connect becomes a MagicMock first, so the
    # dotted 'connect.cursor' targets resolve to attributes of that mock.
    @patch('MySQLdb.connect.cursor.execute')
    @patch('MySQLdb.connect.cursor')
    @patch('MySQLdb.connect')
    def test_execute(self, _connect, _cursor, _execute):
        new_engine = MysqlEngine(instance=self.ins1)
        execute_result = new_engine.execute(self.wf)
        self.assertIsInstance(execute_result, ResultSet)

    @patch.object(MysqlEngine, 'query')
    def test_server_version(self, _query):
        _query.return_value.rows = (('5.7.20',),)
        new_engine = MysqlEngine(instance=self.ins1)
        server_version = new_engine.server_version
        self.assertTupleEqual(server_version, (5, 7, 20))

    @patch.object(MysqlEngine, 'query')
    def test_get_variables_not_filter(self, _query):
        new_engine = MysqlEngine(instance=self.ins1)
        new_engine.get_variables()
        _query.assert_called_once()

    @patch.object(MysqlEngine, 'query')
    def test_get_variables_filter(self, _query):
        new_engine = MysqlEngine(instance=self.ins1)
        new_engine.get_variables(variables=['binlog_format'])
        _query.assert_called()

    @patch.object(MysqlEngine, 'query')
    def test_set_variable(self, _query):
        new_engine = MysqlEngine(instance=self.ins1)
        new_engine.set_variable('binlog_format', 'ROW')
        _query.assert_called_once()
示例#5
0
 def test_query_masking(self, _data_masking):
     """Masking a select statement must still yield a ResultSet."""
     engine = MysqlEngine(instance=self.ins1)
     given = ResultSet()
     select_sql = 'select 1'
     returned = engine.query_masking(db_name='archery', sql=select_sql, resultset=given)
     self.assertIsInstance(returned, ResultSet)
示例#6
0
 def describe_table(self, db_name, tb_name):
     """Return the structure of a table as a ResultSet (rows=list).

     This default implementation ignores its arguments and yields an empty ResultSet.
     """
     empty_result = ResultSet()
     return empty_result
示例#7
0
def query(request):
    """
    Run a query submitted from the SQL-query page and return the result as JSON.

    Reads ``instance_name``, ``sql_content``, ``db_name`` and ``limit_num`` from
    ``request.POST``. Comments are stripped and only the first statement is run;
    it must start with select/show/explain. Successful queries are recorded in
    ``QueryLog``.

    :param request: Django request carrying the POST form and ``request.user``
    :return: HttpResponse whose JSON body has ``status`` (0 = ok, non-zero = error),
             ``msg`` and ``data`` (the ResultSet's ``__dict__`` on success)
    """
    instance_name = request.POST.get('instance_name')
    sql_content = request.POST.get('sql_content')
    db_name = request.POST.get('db_name')
    limit_num = request.POST.get('limit_num')
    user = request.user

    result = {'status': 0, 'msg': 'ok', 'data': {}}

    # Server-side parameter validation. It runs before the instance lookup so a
    # missing instance_name is reported as a missing parameter.
    if sql_content is None or db_name is None or instance_name is None or limit_num is None:
        result['status'] = 1
        result['msg'] = '页面提交参数可能为空'
        return HttpResponse(json.dumps(result), content_type='application/json')

    try:
        instance = Instance.objects.get(instance_name=instance_name)
    except Instance.DoesNotExist:
        result['status'] = 1
        result['msg'] = '实例不存在'
        # A view must return an HttpResponse, never a bare dict.
        return HttpResponse(json.dumps(result), content_type='application/json')

    # Strip comments, then check the syntax of the first statement only.
    sql_content = sqlparse.format(sql_content.strip(), strip_comments=True)
    sql_list = sqlparse.split(sql_content)
    # Comment-only or empty input leaves no statement. Use '' so it fails the
    # syntax check below instead of raising an uncaught IndexError.
    sql_content = sql_list[0].rstrip(';') if sql_list else ''
    if re.match(r"^select|^show|^explain", sql_content, re.I) is None:
        result['status'] = 1
        result['msg'] = '仅支持^select|^show|^explain语法,请联系管理员!'
        return HttpResponse(json.dumps(result), content_type='application/json')

    try:
        # Query-permission check; may also adjust limit_num.
        priv_check_info = query_priv_check(user, instance_name, db_name, sql_content, limit_num)
        if priv_check_info['status'] == 0:
            limit_num = priv_check_info['data']['limit_num']
            priv_check = priv_check_info['data']['priv_check']
        else:
            result['status'] = priv_check_info['status']
            result['msg'] = priv_check_info['msg']
            data = ResultSet(full_sql=sql_content)
            data.error = priv_check_info['msg']
            result['data'] = data.__dict__
            return HttpResponse(json.dumps(result), content_type='application/json')
        # EXPLAIN output is not row-limited.
        limit_num = 0 if re.match(r"^explain", sql_content.lower()) else limit_num

        # Engine-specific query check.
        query_engine = get_engine(instance=instance)
        filter_result = query_engine.query_check(db_name=db_name, sql=sql_content, limit_num=limit_num)
        if filter_result.get('bad_query'):
            # The engine judged the statement a bad query.
            result['status'] = 1
            result['msg'] = filter_result.get('msg')
            return HttpResponse(json.dumps(result), content_type='application/json')
        if filter_result.get('has_star') and SysConfig().get('disable_star') is True:
            # The statement contains '*' and the disable_star option is on.
            result['status'] = 1
            result['msg'] = filter_result.get('msg')
            return HttpResponse(json.dumps(result), content_type='application/json')
        sql_content = filter_result['filtered_sql'] + ';'

        # Execute the query and record how long it took.
        t_start = time.time()
        query_result = query_engine.query(db_name=str(db_name), sql=sql_content, limit_num=limit_num)
        t_end = time.time()
        query_result.query_time = "%5s" % "{:.4f}".format(t_end - t_start)

        # Data masking. It depends on the data_masking / query_check settings.
        # hit_rule: 0 = unknown, 1 = hit, 2 = not hit
        hit_rule = 0 if re.match(r"^select", sql_content.lower()) else 2
        # masking: 1 = result was masked, 2 = not masked
        masking = 2
        t_start = time.time()
        # Only mask select statements that executed without error.
        if SysConfig().get('data_masking') and re.match(r"^select", sql_content.lower()) and query_result.error is None:
            try:
                query_result = query_engine.query_masking(db_name=db_name, sql=sql_content, resultset=query_result)
                if query_result.is_critical is True and SysConfig().get('query_check'):
                    masking_result = {'status': query_result.status,
                                      'msg': query_result.error,
                                      'data': query_result.__dict__}
                    # Rows may hold values the plain encoder cannot serialize,
                    # so use the same encoder as the final response.
                    return HttpResponse(json.dumps(masking_result, cls=ExtendJSONEncoder, bigint_as_string=True),
                                        content_type='application/json')
                else:
                    # Clear the masking status/error so the (possibly unmasked) data is returned.
                    query_result.status = 0
                    query_result.error = None
                    # Report a rule hit and masking only when the engine actually masked something.
                    if query_result.is_masked:
                        masking = 1
                        hit_rule = 1
            except Exception:
                logger.error(traceback.format_exc())
                # Masking failed: not masked, no rule hit.
                hit_rule = 2
                masking = 2
                if SysConfig().get('query_check'):
                    result['status'] = 1
                    result['msg'] = '脱敏数据报错,请联系管理员'
                    return HttpResponse(json.dumps(result), content_type='application/json')

        t_end = time.time()
        query_result.mask_time = "%5s" % "{:.4f}".format(t_end - t_start)
        sql_result = query_result.__dict__

        result['data'] = sql_result

        # Record successful queries in the query log.
        if not sql_result.get('error'):
            if int(limit_num) == 0:
                limit_num = int(sql_result['affected_rows'])
            else:
                limit_num = min(int(limit_num), int(sql_result['affected_rows']))
            query_log = QueryLog(
                username=user.username,
                user_display=user.display,
                db_name=db_name,
                instance_name=instance.instance_name,
                sqllog=sql_content,
                effect_row=limit_num,
                cost_time=query_result.query_time,
                priv_check=priv_check,
                hit_rule=hit_rule,
                masking=masking
            )
            # The DB connection may have timed out during a long query.
            # On failure, close it and retry the save once.
            try:
                query_log.save()
            except Exception:
                connection.close()
                query_log.save()
    except Exception as e:
        logger.error(traceback.format_exc())
        result['status'] = 1
        result['msg'] = str(e)

    # Return the query result.
    try:
        return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
                            content_type='application/json')
    except Exception:
        return HttpResponse(json.dumps(result, default=str, bigint_as_string=True, encoding='latin1'),
                            content_type='application/json')
示例#8
0
 def get_all_columns_by_tb(self, db_name, tb_name):
     """List the columns of ``tb_name`` in ``db_name`` as a ResultSet (rows=list).

     This default implementation returns an empty ResultSet.
     """
     columns = ResultSet()
     return columns
示例#9
0
 def get_all_tables(self, db_name):
     """List the tables in ``db_name`` as a ResultSet (rows=list).

     This default implementation returns an empty ResultSet.
     """
     tables = ResultSet()
     return tables
示例#10
0
 def get_all_databases(self):
     """List the databases on the instance as a ResultSet (rows=list).

     This default implementation returns an empty ResultSet.
     """
     databases = ResultSet()
     return databases
示例#11
0
 def set_variable(self, variable_name, variable_value):
     """Change an instance parameter's value and return a ResultSet.

     This default implementation changes nothing and returns an empty ResultSet.
     """
     outcome = ResultSet()
     return outcome
示例#12
0
 def get_variables(self, variables=None):
     """Fetch instance parameters (optionally only ``variables``) as a ResultSet.

     This default implementation returns an empty ResultSet.
     """
     found = ResultSet()
     return found