def test_query_masking(self):
    """Check that PgSQLEngine.query_masking hands back a result equal to the one passed in."""
    untouched = ResultSet()
    engine = PgSQLEngine(instance=self.ins)
    outcome = engine.query_masking(db_name=0, sql='', resultset=untouched)
    self.assertEqual(outcome, untouched)
def test_query_masking_not_select(self, _data_masking):
    """Check that masking an ``explain`` statement returns a result equal to the input one."""
    untouched = ResultSet()
    engine = MysqlEngine(instance=self.ins1)
    outcome = engine.query_masking(
        db_name='archery',
        sql='explain select 1',
        resultset=untouched,
    )
    self.assertEqual(outcome, untouched)
class TestPgSQL(TestCase):
    """Unit tests for PgSQLEngine.

    The psycopg2 connection and, where needed, PgSQLEngine.query are patched,
    so the tests only exercise the engine's own logic.
    """

    @classmethod
    def setUpClass(cls):
        """Run the base class setup, then persist the shared pgsql Instance."""
        # The base TestCase performs its own class-level setup; it must run
        # before this class touches the database.
        super(TestPgSQL, cls).setUpClass()
        cls.ins = Instance(instance_name='some_ins', type='slave', db_type='pgsql', host='some_host',
                           port=1366, user='******', password='******')
        cls.ins.save()

    @classmethod
    def tearDownClass(cls):
        """Remove the shared Instance, then let the base class clean up."""
        cls.ins.delete()
        super(TestPgSQL, cls).tearDownClass()

    @patch('psycopg2.connect')
    def test_get_connection(self, _conn):
        """get_connection() calls psycopg2.connect exactly once."""
        new_engine = PgSQLEngine(instance=self.ins)
        new_engine.get_connection()
        _conn.assert_called_once()

    @patch('psycopg2.connect.cursor.execute')
    @patch('psycopg2.connect.cursor')
    @patch('psycopg2.connect')
    def test_query(self, _conn, _cursor, _execute):
        """query() returns a ResultSet carrying the rows produced by fetchmany."""
        _conn.return_value.cursor.return_value.fetchmany.return_value = [(1,)]
        new_engine = PgSQLEngine(instance=self.ins)
        query_result = new_engine.query(db_name=0, sql='select 1', limit_num=100)
        self.assertIsInstance(query_result, ResultSet)
        self.assertListEqual(query_result.rows, [(1,)])

    @patch('sql.engines.pgsql.PgSQLEngine.query',
           return_value=ResultSet(rows=[('postgres',), ('archery',), ('template1',), ('template0',)]))
    def test_get_all_databases(self, _query):
        """Of the four mocked database rows, only 'archery' is expected back."""
        new_engine = PgSQLEngine(instance=self.ins)
        dbs = new_engine.get_all_databases()
        self.assertListEqual(dbs, ['archery'])

    @patch('sql.engines.pgsql.PgSQLEngine.query',
           return_value=ResultSet(rows=[('information_schema',), ('archery',), ('pg_catalog',)]))
    def test_get_all_schemas(self, _query):
        """Of the three mocked schema rows, only 'archery' is expected back."""
        new_engine = PgSQLEngine(instance=self.ins)
        schemas = new_engine.get_all_schemas(db_name='archery')
        self.assertListEqual(schemas, ['archery'])

    @patch('sql.engines.pgsql.PgSQLEngine.query',
           return_value=ResultSet(rows=[('test',), ('test2',)]))
    def test_get_all_tables(self, _query):
        """Of the mocked rows 'test' and 'test2', only 'test2' is expected back."""
        new_engine = PgSQLEngine(instance=self.ins)
        tables = new_engine.get_all_tables(db_name='archery', schema_name='archery')
        self.assertListEqual(tables, ['test2'])

    @patch('sql.engines.pgsql.PgSQLEngine.query',
           return_value=ResultSet(rows=[('id',), ('name',)]))
    def test_get_all_columns_by_tb(self, _query):
        """Column rows are flattened into a plain list of column names."""
        new_engine = PgSQLEngine(instance=self.ins)
        columns = new_engine.get_all_columns_by_tb(db_name='archery', tb_name='test2', schema_name='archery')
        self.assertListEqual(columns, ['id', 'name'])

    @patch('sql.engines.pgsql.PgSQLEngine.query',
           return_value=ResultSet(rows=[('postgres',), ('archery',), ('template1',), ('template0',)]))
    def test_describe_table(self, _query):
        """describe_table() returns a ResultSet."""
        new_engine = PgSQLEngine(instance=self.ins)
        describe = new_engine.describe_table(db_name='archery', schema_name='archery', tb_name='text')
        self.assertIsInstance(describe, ResultSet)

    def test_query_check_disable_sql(self):
        """An UPDATE statement is reported as a bad query."""
        sql = "update xxx set a=1 "
        new_engine = PgSQLEngine(instance=self.ins)
        check_result = new_engine.query_check(db_name='archery', sql=sql)
        self.assertDictEqual(check_result,
                             {'msg': '不止的支持查询语法类型!', 'bad_query': True, 'filtered_sql': sql.strip(),
                              'has_star': False})

    def test_query_check_star_sql(self):
        """A ``select *`` is accepted but flagged with has_star."""
        sql = "select * from xx "
        new_engine = PgSQLEngine(instance=self.ins)
        check_result = new_engine.query_check(db_name='archery', sql=sql)
        self.assertDictEqual(check_result,
                             {'msg': 'SQL语句中含有 * ', 'bad_query': False, 'filtered_sql': sql.strip(),
                              'has_star': True})

    def test_filter_sql_with_delimiter(self):
        """A statement ending in ';' gets the limit inserted before the delimiter."""
        sql = "select * from xx;"
        new_engine = PgSQLEngine(instance=self.ins)
        check_result = new_engine.filter_sql(sql=sql, limit_num=100)
        self.assertEqual(check_result, "select * from xx limit 100;")

    def test_filter_sql_without_delimiter(self):
        """A statement without ';' gets both the limit and a trailing delimiter."""
        sql = "select * from xx"
        new_engine = PgSQLEngine(instance=self.ins)
        check_result = new_engine.filter_sql(sql=sql, limit_num=100)
        self.assertEqual(check_result, "select * from xx limit 100;")

    def test_query_masking(self):
        """query_masking() returns a result equal to the ResultSet it was given."""
        query_result = ResultSet()
        new_engine = PgSQLEngine(instance=self.ins)
        masking_result = new_engine.query_masking(db_name=0, sql='', resultset=query_result)
        self.assertEqual(masking_result, query_result)
class TestMysql(TestCase):
    """Tests for MysqlEngine, run against a saved mysql Instance and a finished SqlWorkflow."""

    def setUp(self):
        """Create the Instance, SysConfig handle and workflow fixtures used by the tests."""
        instance_fields = dict(
            instance_name='some_ins',
            type='slave',
            db_type='mysql',
            host='some_host',
            port=1366,
            user='******',
            password='******',
        )
        self.ins1 = Instance(**instance_fields)
        self.ins1.save()
        self.sys_config = SysConfig()
        workflow_fields = dict(
            workflow_name='some_name',
            group_id=1,
            group_name='g1',
            engineer_display='',
            audit_auth_groups='some_group',
            create_time=datetime.now() - timedelta(days=1),
            status='workflow_finish',
            is_backup=True,
            instance=self.ins1,
            db_name='some_db',
            syntax_type=1,
        )
        self.wf = SqlWorkflow.objects.create(**workflow_fields)
        SqlWorkflowContent.objects.create(workflow=self.wf)

    def tearDown(self):
        """Drop the fixtures and reset the system configuration to an empty JSON object."""
        self.ins1.delete()
        self.sys_config.replace(json.dumps({}))
        for model in (SqlWorkflow, SqlWorkflowContent):
            model.objects.all().delete()

    @patch('MySQLdb.connect')
    def testGetConnection(self, connect):
        """get_connection() calls MySQLdb.connect exactly once."""
        engine = MysqlEngine(instance=self.ins1)
        engine.get_connection()
        connect.assert_called_once()

    @patch('MySQLdb.connect')
    def testQuery(self, connect):
        """query() executes, fetches ``limit_num`` rows, closes the connection and returns a ResultSet."""
        cursor_factory = Mock()
        connect.return_value.cursor = cursor_factory
        cursor = cursor_factory.return_value
        cursor.execute = Mock()
        cursor.fetchmany.return_value = (('v1', 'v2'),)
        cursor.description = (('k1', 'some_other_des'), ('k2', 'some_other_des'))
        engine = MysqlEngine(instance=self.ins1)
        outcome = engine.query(sql='some_str', limit_num=100)
        cursor.execute.assert_called()
        cursor.fetchmany.assert_called_once_with(size=100)
        connect.return_value.close.assert_called_once()
        self.assertIsInstance(outcome, ResultSet)

    @patch.object(MysqlEngine, 'query')
    def testAllDb(self, mock_query):
        """get_all_databases() flattens one-column rows into a list of names."""
        stubbed = ResultSet()
        stubbed.rows = [('db_1',), ('db_2',)]
        mock_query.return_value = stubbed
        engine = MysqlEngine(instance=self.ins1)
        databases = engine.get_all_databases()
        self.assertEqual(databases.rows, ['db_1', 'db_2'])

    @patch.object(MysqlEngine, 'query')
    def testAllTables(self, mock_query):
        """get_all_tables() queries the given db once and keeps the first column of each row."""
        stubbed = ResultSet()
        stubbed.rows = [('tb_1', 'some_des'), ('tb_2', 'some_des')]
        mock_query.return_value = stubbed
        engine = MysqlEngine(instance=self.ins1)
        listing = engine.get_all_tables('some_db')
        mock_query.assert_called_once_with(db_name='some_db', sql=ANY)
        self.assertEqual(listing.rows, ['tb_1', 'tb_2'])

    @patch.object(MysqlEngine, 'query')
    def testAllColumns(self, mock_query):
        """get_all_columns_by_tb() keeps the first column of each row."""
        stubbed = ResultSet()
        stubbed.rows = [('col_1', 'type'), ('col_2', 'type2')]
        mock_query.return_value = stubbed
        engine = MysqlEngine(instance=self.ins1)
        columns = engine.get_all_columns_by_tb('some_db', 'some_tb')
        self.assertEqual(columns.rows, ['col_1', 'col_2'])

    @patch.object(MysqlEngine, 'query')
    def testDescribe(self, mock_query):
        """describe_table() issues exactly one query."""
        engine = MysqlEngine(instance=self.ins1)
        engine.describe_table('some_db', 'some_db')
        mock_query.assert_called_once()

    def testQueryCheck(self):
        """query_check() strips the leading comment from the filtered SQL."""
        commented_sql = '-- 测试\n select user from usertable'
        engine = MysqlEngine(instance=self.ins1)
        verdict = engine.query_check(db_name='some_db', sql=commented_sql)
        self.assertEqual(verdict['filtered_sql'], 'select user from usertable')

    def test_query_check_wrong_sql(self):
        """A comment-only statement is reported as a bad query."""
        comment_only = '-- 测试'
        expected = {
            'msg': '不支持的查询语法类型!',
            'bad_query': True,
            'filtered_sql': '-- 测试',
            'has_star': False,
        }
        engine = MysqlEngine(instance=self.ins1)
        verdict = engine.query_check(db_name='some_db', sql=comment_only)
        self.assertDictEqual(verdict, expected)

    def test_query_check_update_sql(self):
        """An UPDATE statement is reported as a bad query."""
        statement = 'update user set id=0'
        expected = {
            'msg': '不支持的查询语法类型!',
            'bad_query': True,
            'filtered_sql': 'update user set id=0',
            'has_star': False,
        }
        engine = MysqlEngine(instance=self.ins1)
        verdict = engine.query_check(db_name='some_db', sql=statement)
        self.assertDictEqual(verdict, expected)

    def test_filter_sql_with_delimiter(self):
        """A statement ending in ';' gets the limit inserted before the delimiter."""
        statement = 'select user from usertable;'
        engine = MysqlEngine(instance=self.ins1)
        limited = engine.filter_sql(sql=statement, limit_num=100)
        self.assertEqual(limited, 'select user from usertable limit 100;')

    def test_filter_sql_without_delimiter(self):
        """A statement without ';' gets both the limit and a trailing delimiter."""
        statement = 'select user from usertable'
        engine = MysqlEngine(instance=self.ins1)
        limited = engine.filter_sql(sql=statement, limit_num=100)
        self.assertEqual(limited, 'select user from usertable limit 100;')

    def test_filter_sql_with_limit(self):
        """With limit_num=1, the statement's own ``limit 10`` is kept in the expected output."""
        statement = 'select user from usertable limit 10'
        engine = MysqlEngine(instance=self.ins1)
        limited = engine.filter_sql(sql=statement, limit_num=1)
        self.assertEqual(limited, 'select user from usertable limit 10;')

    @patch('sql.engines.mysql.data_masking', return_value=ResultSet())
    def test_query_masking(self, _data_masking):
        """Masking a SELECT returns a ResultSet (data_masking is patched)."""
        raw = ResultSet()
        engine = MysqlEngine(instance=self.ins1)
        masked = engine.query_masking(db_name='archery', sql='select 1', resultset=raw)
        self.assertIsInstance(masked, ResultSet)

    @patch('sql.engines.mysql.data_masking', return_value=ResultSet())
    def test_query_masking_not_select(self, _data_masking):
        """Masking an ``explain`` statement returns a result equal to the input one."""
        raw = ResultSet()
        engine = MysqlEngine(instance=self.ins1)
        masked = engine.query_masking(db_name='archery', sql='explain select 1', resultset=raw)
        self.assertEqual(masked, raw)

    def test_execute_check_select_sql(self):
        """execute_check() rejects a SELECT with the expected review row."""
        sql = 'select * from user'
        expected_row = ReviewResult(
            id=1,
            errlevel=2,
            stagestatus='驳回高危SQL',
            errormessage='仅支持DML和DDL语句,查询语句请使用SQL查询功能!',
            sql=sql,
        )
        engine = MysqlEngine(instance=self.ins1)
        review = engine.execute_check(db_name='archery', sql=sql)
        self.assertIsInstance(review, ReviewSet)
        self.assertEqual(vars(review.rows[0]), vars(expected_row))

    def test_execute_check_critical_sql(self):
        """execute_check() rejects a statement matching the configured critical_ddl_regex."""
        pattern = '^|update'
        self.sys_config.set('critical_ddl_regex', pattern)
        self.sys_config.get_all_config()
        sql = 'update user set id=1'
        expected_row = ReviewResult(
            id=1,
            errlevel=2,
            stagestatus='驳回高危SQL',
            errormessage='禁止提交匹配' + pattern + '条件的语句!',
            sql=sql,
        )
        engine = MysqlEngine(instance=self.ins1)
        review = engine.execute_check(db_name='archery', sql=sql)
        self.assertIsInstance(review, ReviewSet)
        self.assertEqual(vars(review.rows[0]), vars(expected_row))

    @patch('sql.engines.mysql.InceptionEngine')
    def test_execute_check_normal_sql(self, _inception_engine):
        """For an ordinary UPDATE, the review row produced by the mocked InceptionEngine is returned."""
        sql = 'update user set id=1'
        inception_row = ReviewResult(
            id=1,
            errlevel=0,
            stagestatus='Audit completed',
            errormessage='None',
            sql=sql,
            affected_rows=0,
            execute_time=0,
        )
        inception = _inception_engine.return_value
        inception.execute_check.return_value = ReviewSet(full_sql=sql, rows=[inception_row])
        engine = MysqlEngine(instance=self.ins1)
        review = engine.execute_check(db_name='archery', sql=sql)
        self.assertIsInstance(review, ReviewSet)
        self.assertEqual(vars(review.rows[0]), vars(inception_row))

    @patch('sql.engines.mysql.InceptionEngine')
    def test_execute_check_normal_sql_with_Exception(self, _inception_engine):
        """A RuntimeError raised by the mocked InceptionEngine propagates out of execute_check()."""
        sql = 'update user set id=1'
        _inception_engine.return_value.execute_check.side_effect = RuntimeError()
        engine = MysqlEngine(instance=self.ins1)
        with self.assertRaises(RuntimeError):
            engine.execute_check(db_name=0, sql=sql)

    @patch('sql.engines.mysql.InceptionEngine')
    def test_execute_workflow(self, _inception_engine):
        """execute_workflow() returns a ReviewSet when InceptionEngine.execute is mocked."""
        sql = 'update user set id=1'
        _inception_engine.return_value.execute.return_value = ReviewSet(full_sql=sql)
        engine = MysqlEngine(instance=self.ins1)
        outcome = engine.execute_workflow(self.wf)
        self.assertIsInstance(outcome, ReviewSet)

    @patch('MySQLdb.connect.cursor.execute')
    @patch('MySQLdb.connect.cursor')
    @patch('MySQLdb.connect')
    def test_execute(self, _connect, _cursor, _execute):
        """execute() returns a ResultSet when the MySQLdb connection is mocked."""
        engine = MysqlEngine(instance=self.ins1)
        outcome = engine.execute(self.wf)
        self.assertIsInstance(outcome, ResultSet)

    @patch.object(MysqlEngine, 'query')
    def test_server_version(self, _query):
        """server_version turns the queried '5.7.20' into the tuple (5, 7, 20)."""
        _query.return_value.rows = (('5.7.20',),)
        engine = MysqlEngine(instance=self.ins1)
        version = engine.server_version
        self.assertTupleEqual(version, (5, 7, 20))

    @patch.object(MysqlEngine, 'query')
    def test_get_variables_not_filter(self, _query):
        """get_variables() without a filter issues exactly one query."""
        engine = MysqlEngine(instance=self.ins1)
        engine.get_variables()
        _query.assert_called_once()

    @patch.object(MysqlEngine, 'query')
    def test_get_variables_filter(self, _query):
        """get_variables() with a variable list issues a query."""
        engine = MysqlEngine(instance=self.ins1)
        engine.get_variables(variables=['binlog_format'])
        _query.assert_called()

    @patch.object(MysqlEngine, 'query')
    def test_set_variable(self, _query):
        """set_variable() issues exactly one query."""
        engine = MysqlEngine(instance=self.ins1)
        engine.set_variable('binlog_format', 'ROW')
        _query.assert_called_once()
def test_query_masking(self, _data_masking):
    """Check that MysqlEngine.query_masking yields a ResultSet for a SELECT statement."""
    raw = ResultSet()
    engine = MysqlEngine(instance=self.ins1)
    masked = engine.query_masking(
        db_name='archery',
        sql='select 1',
        resultset=raw,
    )
    self.assertIsInstance(masked, ResultSet)
def describe_table(self, db_name, tb_name):
    """Get a table's structure as a ResultSet whose ``rows`` is a list.

    This implementation ignores both arguments and returns a
    default-constructed ResultSet.
    """
    placeholder = ResultSet()
    return placeholder
def query(request):
    """Run one read-only SQL statement submitted by POST and return the result as JSON.

    Reads ``instance_name``, ``sql_content``, ``db_name`` and ``limit_num``
    from ``request.POST``. Only the first statement of ``sql_content`` is
    executed, and it must start with select/show/explain. The statement goes
    through the privilege check, the engine's query check, execution and
    (when enabled in SysConfig) data masking. A query that finishes without
    an error is recorded in QueryLog.

    :param request: the Django request; ``request.user`` is the querying user
    :return: an HttpResponse whose JSON body has the keys ``status``
             (0 on success), ``msg`` and ``data``
    """
    instance_name = request.POST.get('instance_name')
    sql_content = request.POST.get('sql_content')
    db_name = request.POST.get('db_name')
    limit_num = request.POST.get('limit_num')
    user = request.user

    result = {'status': 0, 'msg': 'ok', 'data': {}}
    try:
        instance = Instance.objects.get(instance_name=instance_name)
    except Instance.DoesNotExist:
        result['status'] = 1
        result['msg'] = '实例不存在'
        # A view has to answer with an HttpResponse, not a bare dict.
        return HttpResponse(json.dumps(result), content_type='application/json')

    # Server-side parameter validation
    if sql_content is None or db_name is None or instance_name is None or limit_num is None:
        result['status'] = 1
        result['msg'] = '页面提交参数可能为空'
        return HttpResponse(json.dumps(result), content_type='application/json')

    # Strip comments before checking the statement type
    sql_content = sqlparse.format(sql_content.strip(), strip_comments=True)
    sql_list = sqlparse.split(sql_content)
    if not sql_list:
        # Nothing is left once comments and whitespace are removed; report it
        # as an empty submission instead of failing on sql_list[0].
        result['status'] = 1
        result['msg'] = '页面提交参数可能为空'
        return HttpResponse(json.dumps(result), content_type='application/json')
    # Only the first statement is executed
    sql_content = sql_list[0].rstrip(';')
    if re.match(r"^select|^show|^explain", sql_content, re.I) is None:
        result['status'] = 1
        result['msg'] = '仅支持^select|^show|^explain语法,请联系管理员!'
        return HttpResponse(json.dumps(result), content_type='application/json')

    try:
        # Query privilege check
        priv_check_info = query_priv_check(user, instance_name, db_name, sql_content, limit_num)
        if priv_check_info['status'] == 0:
            limit_num = priv_check_info['data']['limit_num']
            priv_check = priv_check_info['data']['priv_check']
        else:
            result['status'] = priv_check_info['status']
            result['msg'] = priv_check_info['msg']
            data = ResultSet(full_sql=sql_content)
            data.error = priv_check_info['msg']
            result['data'] = data.__dict__
            return HttpResponse(json.dumps(result), content_type='application/json')
        # explain statements are run with limit_num 0
        limit_num = 0 if re.match(r"^explain", sql_content.lower()) else limit_num

        # Engine-side query check
        query_engine = get_engine(instance=instance)
        filter_result = query_engine.query_check(db_name=db_name, sql=sql_content, limit_num=limit_num)
        if filter_result.get('bad_query'):
            # The engine judged the statement to be a bad query
            result['status'] = 1
            result['msg'] = filter_result.get('msg')
            return HttpResponse(json.dumps(result), content_type='application/json')
        if filter_result.get('has_star') and SysConfig().get('disable_star') is True:
            # The engine found a '*' and the disable_star option is switched on
            result['status'] = 1
            result['msg'] = filter_result.get('msg')
            return HttpResponse(json.dumps(result), content_type='application/json')
        else:
            sql_content = filter_result['filtered_sql']
        sql_content = sql_content + ';'

        # Execute the statement and time it
        t_start = time.time()
        query_result = query_engine.query(db_name=str(db_name), sql=sql_content, limit_num=limit_num)
        t_end = time.time()
        query_result.query_time = "%5s" % "{:.4f}".format(t_end - t_start)

        # Data masking: depends on whether masking is enabled and on whether
        # a masking failure is allowed to fall through to the raw result.
        # hit_rule - did the query hit a masking rule: 0 unknown, 1 hit, 2 not hit
        hit_rule = 0 if re.match(r"^select", sql_content.lower()) else 2
        # masking - was the result masked: 1 yes, 2 no
        masking = 2
        t_start = time.time()
        # Only mask statements that executed without error
        if SysConfig().get('data_masking') and re.match(r"^select", sql_content.lower()) \
                and query_result.error is None:
            try:
                query_result = query_engine.query_masking(db_name=db_name, sql=sql_content,
                                                          resultset=query_result)
                if query_result.is_critical is True and SysConfig().get('query_check'):
                    masking_result = {'status': query_result.status, 'msg': query_result.error,
                                      'data': query_result.__dict__}
                    return HttpResponse(json.dumps(masking_result), content_type='application/json')
                else:
                    # Reset the masking status and return the data as it stands
                    query_result.status = 0
                    query_result.error = None
                    # Record a hit only when the result was actually masked
                    if query_result.is_masked:
                        masking = 1
                        hit_rule = 1
            except Exception:
                logger.error(traceback.format_exc())
                # Masking failed: record as not masked, rule not hit
                hit_rule = 2
                masking = 2
                if SysConfig().get('query_check'):
                    result['status'] = 1
                    result['msg'] = '脱敏数据报错,请联系管理员'
                    return HttpResponse(json.dumps(result), content_type='application/json')

        t_end = time.time()
        query_result.mask_time = "%5s" % "{:.4f}".format(t_end - t_start)
        sql_result = query_result.__dict__

        result['data'] = sql_result

        # Record successful queries in the database
        if not sql_result.get('error'):
            if int(limit_num) == 0:
                limit_num = int(sql_result['affected_rows'])
            else:
                limit_num = min(int(limit_num), int(sql_result['affected_rows']))
            query_log = QueryLog(
                username=user.username,
                user_display=user.display,
                db_name=db_name,
                instance_name=instance.instance_name,
                sqllog=sql_content,
                effect_row=limit_num,
                cost_time=query_result.query_time,
                priv_check=priv_check,
                hit_rule=hit_rule,
                masking=masking
            )
            # Guard against the query having timed out: if the first save
            # fails, close the connection and try once more.
            try:
                query_log.save()
            except Exception:
                connection.close()
                query_log.save()
    except Exception as e:
        logger.error(traceback.format_exc())
        result['status'] = 1
        result['msg'] = str(e)

    # Return the query result; if the project encoder cannot serialise it,
    # fall back to stringifying unknown values.
    try:
        return HttpResponse(json.dumps(result, cls=ExtendJSONEncoder, bigint_as_string=True),
                            content_type='application/json')
    except Exception:
        return HttpResponse(json.dumps(result, default=str, bigint_as_string=True, encoding='latin1'),
                            content_type='application/json')
def get_all_columns_by_tb(self, db_name, tb_name):
    """Get all columns of a table as a ResultSet whose ``rows`` is a list.

    This implementation ignores both arguments and returns a
    default-constructed ResultSet.
    """
    placeholder = ResultSet()
    return placeholder
def get_all_tables(self, db_name):
    """Get the table list as a ResultSet whose ``rows`` is a list.

    This implementation ignores ``db_name`` and returns a
    default-constructed ResultSet.
    """
    placeholder = ResultSet()
    return placeholder
def get_all_databases(self):
    """Get the database list as a ResultSet whose ``rows`` is a list.

    This implementation returns a default-constructed ResultSet.
    """
    placeholder = ResultSet()
    return placeholder
def set_variable(self, variable_name, variable_value):
    """Change an instance parameter value and return a ResultSet.

    This implementation ignores both arguments and returns a
    default-constructed ResultSet.
    """
    placeholder = ResultSet()
    return placeholder
def get_variables(self, variables=None):
    """Get instance parameters as a ResultSet.

    This implementation ignores ``variables`` and returns a
    default-constructed ResultSet.
    """
    placeholder = ResultSet()
    return placeholder