class TestSecurity(unittest.TestCase):
    """Tests for ``init_role`` run against a freshly built AppBuilder app."""

    def setUp(self):
        """Build the Flask app, register the two sample views, add an admin."""
        self.app = Flask(__name__)
        self.app.config.update({
            'SQLALCHEMY_DATABASE_URI': 'sqlite:///',
            'SECRET_KEY': 'secret_key',
            'CSRF_ENABLED': False,
            'WTF_CSRF_ENABLED': False,
        })
        self.db = SQLA(self.app)
        self.appbuilder = AppBuilder(self.app, self.db.session)

        # (view class, menu label, menu category)
        sample_views = (
            (SomeBaseView, "SomeBaseView", "BaseViews"),
            (SomeModelView, "SomeModelView", "ModelViews"),
        )
        for view_cls, label, menu_category in sample_views:
            self.appbuilder.add_view(view_cls, label, category=menu_category)

        self.user = self.appbuilder.sm.add_user(
            'admin',
            'admin',
            'user',
            '*****@*****.**',
            self.appbuilder.sm.find_role('Admin'),
            'general',
        )
        log.debug("Complete setup!")

    def tearDown(self):
        """Drop the references created by ``setUp``."""
        self.appbuilder = self.app = self.db = None
        log.debug("Complete teardown!")

    def _assert_role_initialised(self, role_name, role_vms, role_perms):
        # Create the role, then check it exists and carries one permission
        # entry per requested permission name.
        init_role(self.appbuilder.sm, role_name, role_vms, role_perms)
        created = self.appbuilder.sm.find_role(role_name)
        self.assertIsNotNone(created)
        self.assertEqual(len(role_perms), len(created.permissions))

    def _assert_init_role_rejected(self, role_name, role_vms, role_perms,
                                   expected_message):
        # init_role must raise, and the exception text must match exactly.
        with self.assertRaises(Exception) as context:
            init_role(self.appbuilder.sm, role_name, role_vms, role_perms)
        self.assertEqual(expected_message, str(context.exception))

    def test_init_role_baseview(self):
        """A role can be initialised with a permission on a base view."""
        self._assert_role_initialised(
            'MyRole1', ['SomeBaseView'], ['can_some_action'])

    def test_init_role_modelview(self):
        """A role can be initialised with the CRUD permissions of a model view."""
        self._assert_role_initialised(
            'MyRole2',
            ['SomeModelView'],
            ['can_list', 'can_show', 'can_add', 'can_edit', 'can_delete'],
        )

    def test_invalid_perms(self):
        """An unknown permission name makes ``init_role`` raise."""
        self._assert_init_role_rejected(
            'MyRole3',
            ['SomeBaseView'],
            ['can_foo'],
            "The following permissions are not valid: ['can_foo']",
        )

    def test_invalid_vms(self):
        """An unknown view-menu name makes ``init_role`` raise."""
        self._assert_init_role_rejected(
            'MyRole4',
            ['NonExistentBaseView'],
            ['can_some_action'],
            "The following view menus are not valid: ['NonExistentBaseView']",
        )
def setUp(self):
    """Build a Flask-AppBuilder app with the two sample views and an admin user."""
    self.app = Flask(__name__)
    # Same four settings as individual assignments, applied in one call.
    self.app.config.update(
        SQLALCHEMY_DATABASE_URI='sqlite:///',
        SECRET_KEY='secret_key',
        CSRF_ENABLED=False,
        WTF_CSRF_ENABLED=False,
    )
    self.db = SQLA(self.app)
    self.appbuilder = AppBuilder(self.app, self.db.session)

    # (view class, menu label, menu category), registered in this order.
    for view_cls, label, menu_category in (
        (SomeBaseView, "SomeBaseView", "BaseViews"),
        (SomeModelView, "SomeModelView", "ModelViews"),
    ):
        self.appbuilder.add_view(view_cls, label, category=menu_category)

    admin_role = self.appbuilder.sm.find_role('Admin')
    self.user = self.appbuilder.sm.add_user(
        'admin', 'admin', 'user', '*****@*****.**', admin_role, 'general'
    )
    log.debug("Complete setup!")
def setUp(self):
    """Create a MongoEngine-backed Flask-AppBuilder app with the test views.

    Configures a Flask app that points at the Mongo database named ``test``,
    disables CSRF, declares the model, chart and master/detail views used by
    the tests, registers them on the AppBuilder and finally tries to add an
    ``admin`` user holding the ``Admin`` role.
    """
    from flask import Flask
    from flask_appbuilder import AppBuilder
    from flask_appbuilder.models.mongoengine.interface import MongoEngineInterface
    from flask_appbuilder import ModelView
    from flask_appbuilder.security.mongoengine.manager import SecurityManager

    self.app = Flask(__name__)
    self.basedir = os.path.abspath(os.path.dirname(__file__))
    self.app.config['MONGODB_SETTINGS'] = {'DB': 'test'}
    self.app.config['CSRF_ENABLED'] = False
    self.app.config['SECRET_KEY'] = 'thisismyscretkey'
    self.app.config['WTF_CSRF_ENABLED'] = False
    self.db = MongoEngine(self.app)
    self.appbuilder = AppBuilder(self.app, security_manager_class=SecurityManager)

    class Model2View(ModelView):
        datamodel = MongoEngineInterface(Model2)
        list_columns = [
            'field_integer',
            'field_float',
            'field_string',
            'field_method',
            'group.field_string',
        ]
        # The edit form offers only group "G2", the add form only group "G1".
        edit_form_query_rel_fields = {
            'group': [['field_string', FilterEqual, 'G2']]
        }
        add_form_query_rel_fields = {
            'group': [['field_string', FilterEqual, 'G1']]
        }
        add_exclude_columns = ['excluded_string']

    class Model22View(ModelView):
        datamodel = MongoEngineInterface(Model2)
        list_columns = [
            'field_integer',
            'field_float',
            'field_string',
            'field_method',
            'group.field_string',
        ]
        add_exclude_columns = ['excluded_string']
        edit_exclude_columns = ['excluded_string']
        show_exclude_columns = ['excluded_string']

    class Model1View(ModelView):
        datamodel = MongoEngineInterface(Model1)
        related_views = [Model2View]
        list_columns = ['field_string', 'field_file']

    class Model1CompactView(CompactCRUDMixin, ModelView):
        datamodel = MongoEngineInterface(Model1)

    class Model1Filtered1View(ModelView):
        datamodel = MongoEngineInterface(Model1)
        base_filters = [['field_string', FilterStartsWith, 'a']]

    # Declared once only: the original repeated this identical class further
    # down, which merely rebound the same name to an equivalent class.
    class Model1MasterView(MasterDetailView):
        datamodel = MongoEngineInterface(Model1)
        related_views = [Model2View]

    class Model1Filtered2View(ModelView):
        datamodel = MongoEngineInterface(Model1)
        base_filters = [['field_integer', FilterEqual, 0]]

    class Model2GroupByChartView(GroupByChartView):
        datamodel = MongoEngineInterface(Model2)
        chart_title = 'Test Model1 Chart'
        definitions = [{
            'group': 'field_string',
            'series': [(aggregate_sum, 'field_integer',
                        aggregate_avg, 'field_integer',
                        aggregate_count, 'field_integer')]
        }]

    class Model2DirectByChartView(DirectByChartView):
        datamodel = MongoEngineInterface(Model2)
        chart_title = 'Test Model1 Chart'
        definitions = [{
            'group': 'field_string',
            'series': ['field_integer', 'field_float']
        }]

    class Model2DirectChartView(DirectChartView):
        datamodel = MongoEngineInterface(Model2)
        chart_title = 'Test Model1 Chart'
        direct_columns = {'stat1': ('group', 'field_integer')}

    class Model1MasterChartView(MasterDetailView):
        datamodel = MongoEngineInterface(Model1)
        related_views = [Model2DirectByChartView]

    self.appbuilder.add_view(Model1View, "Model1", category='Model1')
    self.appbuilder.add_view(Model1CompactView, "Model1Compact", category='Model1')
    self.appbuilder.add_view(Model1MasterView, "Model1Master", category='Model1')
    self.appbuilder.add_view(Model1MasterChartView, "Model1MasterChart", category='Model1')
    self.appbuilder.add_view(Model1Filtered1View, "Model1Filtered1", category='Model1')
    self.appbuilder.add_view(Model1Filtered2View, "Model1Filtered2", category='Model1')

    self.appbuilder.add_view(Model2View, "Model2")
    self.appbuilder.add_view(Model22View, "Model22")
    self.appbuilder.add_view(Model2View, "Model2 Add", href='/model2view/add')
    self.appbuilder.add_view(Model2GroupByChartView, "Model2 Group By Chart")
    self.appbuilder.add_view(Model2DirectByChartView, "Model2 Direct By Chart")
    self.appbuilder.add_view(Model2DirectChartView, "Model2 Direct Chart")

    role_admin = self.appbuilder.sm.find_role('Admin')
    try:
        self.appbuilder.sm.add_user('admin', 'admin', 'user', '*****@*****.**', role_admin, 'general')
    except Exception:
        # Best-effort on purpose: presumably the user can already exist in
        # the shared "test" database from an earlier run -- TODO confirm.
        # Narrowed from a bare ``except:`` so KeyboardInterrupt and
        # SystemExit are no longer swallowed here.
        pass
# NOTE(review): the ``try:`` this handler belongs to starts before this
# excerpt, so what is being attempted cannot be seen from here.
except OSError:
    pass

# Wrap the WSGI callable once per configured middleware, in iteration order,
# so the last entry ends up outermost.
for middleware in app.config.get('ADDITIONAL_MIDDLEWARE'):
    app.wsgi_app = middleware(app.wsgi_app)


# Index view whose root URL redirects to the Superset welcome page.
class MyIndexView(IndexView):
    @expose('/')
    def index(self):
        return redirect('/superset/welcome')


# The security manager class is read from the CUSTOM_SECURITY_MANAGER setting.
appbuilder = AppBuilder(
    app,
    db.session,
    base_template='superset/base.html',
    indexview=MyIndexView,
    security_manager_class=app.config.get("CUSTOM_SECURITY_MANAGER"))

# Module-level shortcuts to the AppBuilder's security manager and session getter.
sm = appbuilder.sm

get_session = appbuilder.get_session
results_backend = app.config.get("RESULTS_BACKEND")

# Registering sources
# The mapping stored under DEFAULT_MODULE_DS_MAP is updated in place with the
# additional entries before it is handed to the registry.
module_datasource_map = app.config.get("DEFAULT_MODULE_DS_MAP")
module_datasource_map.update(app.config.get("ADDITIONAL_MODULE_DS_MAP"))
ConnectorRegistry.register_sources(module_datasource_map)

# Imported last, after ``app`` and ``appbuilder`` exist -- presumably the views
# module needs them at import time; verify.
from superset import views  # noqa
def setUp(self):
    """Prepare a MongoEngine-backed AppBuilder app and register the test views.

    The app uses strict Jinja undefined handling, targets the Mongo database
    named ``test`` and has CSRF disabled. After the views are registered an
    ``admin`` user with the ``Admin`` role is added on a best-effort basis.
    """
    from flask import Flask
    from flask_appbuilder import AppBuilder
    from flask_appbuilder.models.mongoengine.interface import MongoEngineInterface
    from flask_appbuilder import ModelView
    from flask_appbuilder.security.mongoengine.manager import SecurityManager

    self.app = Flask(__name__)
    self.app.jinja_env.undefined = jinja2.StrictUndefined
    self.basedir = os.path.abspath(os.path.dirname(__file__))
    self.app.config.update(
        {
            "MONGODB_SETTINGS": {"DB": "test"},
            "CSRF_ENABLED": False,
            "SECRET_KEY": "thisismyscretkey",
            "WTF_CSRF_ENABLED": False,
        }
    )
    self.db = MongoEngine(self.app)
    self.appbuilder = AppBuilder(self.app, security_manager_class=SecurityManager)

    # Column names listed by both Model2 views; each view takes its own copy.
    model2_columns = (
        "field_integer",
        "field_float",
        "field_string",
        "field_method",
        "group.field_string",
    )
    hidden_columns = ("excluded_string",)

    class Model2View(ModelView):
        datamodel = MongoEngineInterface(Model2)
        list_columns = list(model2_columns)
        # Edit form offers only group "G2", add form only group "G1".
        edit_form_query_rel_fields = {
            "group": [["field_string", FilterEqual, "G2"]]
        }
        add_form_query_rel_fields = {
            "group": [["field_string", FilterEqual, "G1"]]
        }
        add_exclude_columns = list(hidden_columns)

    class Model22View(ModelView):
        datamodel = MongoEngineInterface(Model2)
        list_columns = list(model2_columns)
        add_exclude_columns = list(hidden_columns)
        edit_exclude_columns = list(hidden_columns)
        show_exclude_columns = list(hidden_columns)

    class Model1View(ModelView):
        datamodel = MongoEngineInterface(Model1)
        related_views = [Model2View]
        list_columns = ["field_string", "field_file"]

    class Model1CompactView(CompactCRUDMixin, ModelView):
        datamodel = MongoEngineInterface(Model1)

    class Model1Filtered1View(ModelView):
        datamodel = MongoEngineInterface(Model1)
        base_filters = [["field_string", FilterStartsWith, "a"]]

    class Model1MasterView(MasterDetailView):
        datamodel = MongoEngineInterface(Model1)
        related_views = [Model2View]

    class Model1Filtered2View(ModelView):
        datamodel = MongoEngineInterface(Model1)
        base_filters = [["field_integer", FilterEqual, 0]]

    class Model2GroupByChartView(GroupByChartView):
        datamodel = MongoEngineInterface(Model2)
        chart_title = "Test Model1 Chart"
        definitions = [
            {
                "group": "field_string",
                "series": [
                    (
                        aggregate_sum,
                        "field_integer",
                        aggregate_avg,
                        "field_integer",
                        aggregate_count,
                        "field_integer",
                    )
                ],
            }
        ]

    class Model2DirectByChartView(DirectByChartView):
        datamodel = MongoEngineInterface(Model2)
        chart_title = "Test Model1 Chart"
        definitions = [
            {"group": "field_string", "series": ["field_integer", "field_float"]}
        ]

    class Model2DirectChartView(DirectChartView):
        datamodel = MongoEngineInterface(Model2)
        chart_title = "Test Model1 Chart"
        direct_columns = {"stat1": ("group", "field_integer")}

    class Model1MasterChartView(MasterDetailView):
        datamodel = MongoEngineInterface(Model1)
        related_views = [Model2DirectByChartView]

    # (view class, menu label, extra add_view keyword arguments), kept in the
    # registration order of the original explicit calls.
    registrations = (
        (Model1View, "Model1", {"category": "Model1"}),
        (Model1CompactView, "Model1Compact", {"category": "Model1"}),
        (Model1MasterView, "Model1Master", {"category": "Model1"}),
        (Model1MasterChartView, "Model1MasterChart", {"category": "Model1"}),
        (Model1Filtered1View, "Model1Filtered1", {"category": "Model1"}),
        (Model1Filtered2View, "Model1Filtered2", {"category": "Model1"}),
        (Model2View, "Model2", {}),
        (Model22View, "Model22", {}),
        (Model2View, "Model2 Add", {"href": "/model2view/add"}),
        (Model2GroupByChartView, "Model2 Group By Chart", {}),
        (Model2DirectByChartView, "Model2 Direct By Chart", {}),
        (Model2DirectChartView, "Model2 Direct Chart", {}),
    )
    for view_cls, label, extra in registrations:
        self.appbuilder.add_view(view_cls, label, **extra)

    security_manager = self.appbuilder.sm
    admin_role = security_manager.find_role("Admin")
    try:
        security_manager.add_user(
            "admin", "admin", "user", "*****@*****.**", admin_role, "general"
        )
    except Exception:
        # Adding the user is best-effort; a failure here is ignored.
        pass
from flask import Flask, request, render_template
from flask_sqlalchemy import SQLAlchemy
from .config import config, Config
from flask_appbuilder import AppBuilder
import jpush

# SQLAlchemy handle created without an app; create_app() binds it later.
db = SQLAlchemy()
# JPush client built from the credentials held on the Config class.
_jpush = jpush.JPush(Config.app_key, Config.master_secret)
_jpush.set_logging('DEBUG')

# Imported only after ``db`` exists -- presumably app.controller imports names
# from this module; verify before moving this to the top of the file.
from app.controller import MyIndexView

# AppBuilder created without an app; create_app() initialises it.
appbuilder = AppBuilder(indexview=MyIndexView)
api_version = 'v1'


def create_app(config_name):
    """Create a Flask app from ``config[config_name]`` and bind ``db`` and ``appbuilder`` to it.

    :param config_name: key used to look up a settings object in ``config``.

    NOTE(review): no ``return`` statement is visible in this excerpt; confirm
    whether the function continues beyond it.
    """
    app = Flask(__name__)
    app.config.from_object(config[config_name])
    # Writes the configured database URI to stdout on every call.
    print(app.config['SQLALCHEMY_DATABASE_URI'])

    db.app = app
    db.init_app(app)

    appbuilder.app = app
    appbuilder.init_app(app, db.session)

    # These names are referenced only by the commented-out hook below, as far
    # as this excerpt shows.
    from app.controller import Auth, UserError, Message
    from app.model import User

    #@app.before_request
    #def before_request():
    #    user_id = request.form.get('user_id')
    #    token = request.headers.get('Authorization')
    #    if not Auth.authToken(user_id, token):
    #        return str(Message(None, *UserError.AUTH_FAILED))
import logging

from flask import Flask
from flask_appbuilder import SQLA, AppBuilder

from .index import VisulizeIndexView

# Logging configuration: the root logger runs at DEBUG level and records are
# formatted as "time:level:logger name:message".
logging.getLogger().setLevel(logging.DEBUG)
logging.basicConfig(format='%(asctime)s:%(levelname)s:%(name)s:%(message)s')

# Application objects: the Flask app configured from the ``config`` object,
# the SQLAlchemy handle bound to it, and the AppBuilder using the custom
# index view.
app = Flask(__name__)
app.config.from_object('config')
db = SQLA(app)
appbuilder = AppBuilder(app, db.session, indexview=VisulizeIndexView)

# Disabled snippet, kept for reference.
#
# from sqlalchemy.engine import Engine
# from sqlalchemy import event
#
# # Only include this for SQLite constraints
# @event.listens_for(Engine, "connect")
# def set_sqlite_pragma(dbapi_connection, connection_record):
#     # Will force SQLite foreign key constraints
#     cursor = dbapi_connection.cursor()
#     cursor.execute("PRAGMA foreign_keys=ON")
#     cursor.close()
class FlaskTestCase(FABTestCase): def setUp(self): from flask import Flask from flask_appbuilder import AppBuilder from flask_appbuilder.models.sqla.interface import SQLAInterface from flask_appbuilder.views import ModelView from sqlalchemy.engine import Engine from sqlalchemy import event self.app = Flask(__name__) self.app.jinja_env.undefined = jinja2.StrictUndefined self.basedir = os.path.abspath(os.path.dirname(__file__)) self.app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///" self.app.config["CSRF_ENABLED"] = False self.app.config["SECRET_KEY"] = "thisismyscretkey" self.app.config["WTF_CSRF_ENABLED"] = False self.app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False self.app.config["FAB_ROLES"] = { "ReadOnly": [ [".*", "can_list"], [".*", "can_show"] ] } logging.basicConfig(level=logging.ERROR) @event.listens_for(Engine, "connect") def set_sqlite_pragma(dbapi_connection, connection_record): # Will force sqllite contraint foreign keys cursor = dbapi_connection.cursor() cursor.execute("PRAGMA foreign_keys=ON") cursor.close() self.db = SQLA(self.app) self.appbuilder = AppBuilder(self.app, self.db.session) sess = PSSession() class PSView(ModelView): datamodel = GenericInterface(PSModel, sess) base_permissions = ["can_list", "can_show"] list_columns = ["UID", "C", "CMD", "TIME"] search_columns = ["UID", "C", "CMD"] class Model2View(ModelView): datamodel = SQLAInterface(Model2) list_columns = [ "field_integer", "field_float", "field_string", "field_method", "group.field_string", ] edit_form_query_rel_fields = { "group": [["field_string", FilterEqual, "G2"]] } add_form_query_rel_fields = {"group": [["field_string", FilterEqual, "G1"]]} class Model22View(ModelView): datamodel = SQLAInterface(Model2) list_columns = [ "field_integer", "field_float", "field_string", "field_method", "group.field_string", ] add_exclude_columns = ["excluded_string"] edit_exclude_columns = ["excluded_string"] show_exclude_columns = ["excluded_string"] class Model1View(ModelView): datamodel 
= SQLAInterface(Model1) related_views = [Model2View] list_columns = ["field_string", "field_file"] class Model3View(ModelView): datamodel = SQLAInterface(Model3) list_columns = ["pk1", "pk2", "field_string"] add_columns = ["pk1", "pk2", "field_string"] edit_columns = ["pk1", "pk2", "field_string"] class Model1CompactView(CompactCRUDMixin, ModelView): datamodel = SQLAInterface(Model1) class Model3CompactView(CompactCRUDMixin, ModelView): datamodel = SQLAInterface(Model3) class Model1ViewWithRedirects(ModelView): datamodel = SQLAInterface(Model1) obj_id = 1 def post_add_redirect(self): return redirect( "/model1viewwithredirects/show/{0}".format(REDIRECT_OBJ_ID) ) def post_edit_redirect(self): return redirect( "/model1viewwithredirects/show/{0}".format(REDIRECT_OBJ_ID) ) def post_delete_redirect(self): return redirect( "/model1viewwithredirects/show/{0}".format(REDIRECT_OBJ_ID) ) class Model1Filtered1View(ModelView): datamodel = SQLAInterface(Model1) base_filters = [["field_string", FilterStartsWith, "a"]] class Model1MasterView(MasterDetailView): datamodel = SQLAInterface(Model1) related_views = [Model2View] class Model1Filtered2View(ModelView): datamodel = SQLAInterface(Model1) base_filters = [["field_integer", FilterEqual, 0]] class Model2ChartView(ChartView): datamodel = SQLAInterface(Model2) chart_title = "Test Model1 Chart" group_by_columns = ["field_string"] class Model2GroupByChartView(GroupByChartView): datamodel = SQLAInterface(Model2) chart_title = "Test Model1 Chart" definitions = [ { "group": "field_string", "series": [ ( aggregate_sum, "field_integer", aggregate_avg, "field_integer", aggregate_count, "field_integer", ) ], } ] class Model2DirectByChartView(DirectByChartView): datamodel = SQLAInterface(Model2) chart_title = "Test Model1 Chart" list_title = "" definitions = [ {"group": "field_string", "series": ["field_integer", "field_float"]} ] class Model2TimeChartView(TimeChartView): datamodel = SQLAInterface(Model2) chart_title = "Test Model1 Chart" 
group_by_columns = ["field_date"] class Model2DirectChartView(DirectChartView): datamodel = SQLAInterface(Model2) chart_title = "Test Model1 Chart" direct_columns = {"stat1": ("group", "field_integer")} class Model1MasterChartView(MasterDetailView): datamodel = SQLAInterface(Model1) related_views = [Model2DirectByChartView] class Model1FormattedView(ModelView): datamodel = SQLAInterface(Model1) list_columns = ["field_string"] show_columns = ["field_string"] formatters_columns = {"field_string": lambda x: "FORMATTED_STRING"} class ModelWithEnumsView(ModelView): datamodel = SQLAInterface(ModelWithEnums) self.appbuilder.add_view(Model1View, "Model1", category="Model1") self.appbuilder.add_view( Model1ViewWithRedirects, "Model1ViewWithRedirects", category="Model1" ) self.appbuilder.add_view(Model1CompactView, "Model1Compact", category="Model1") self.appbuilder.add_view(Model1MasterView, "Model1Master", category="Model1") self.appbuilder.add_view( Model1MasterChartView, "Model1MasterChart", category="Model1" ) self.appbuilder.add_view( Model1Filtered1View, "Model1Filtered1", category="Model1" ) self.appbuilder.add_view( Model1Filtered2View, "Model1Filtered2", category="Model1" ) self.appbuilder.add_view( Model1FormattedView, "Model1FormattedView", category="Model1FormattedView" ) self.appbuilder.add_view(Model2View, "Model2") self.appbuilder.add_view(Model22View, "Model22") self.appbuilder.add_view(Model2View, "Model2 Add", href="/model2view/add") self.appbuilder.add_view(Model2ChartView, "Model2 Chart") self.appbuilder.add_view(Model2GroupByChartView, "Model2 Group By Chart") self.appbuilder.add_view(Model2DirectByChartView, "Model2 Direct By Chart") self.appbuilder.add_view(Model2TimeChartView, "Model2 Time Chart") self.appbuilder.add_view(Model2DirectChartView, "Model2 Direct Chart") self.appbuilder.add_view(Model3View, "Model3") self.appbuilder.add_view(Model3CompactView, "Model3Compact") self.appbuilder.add_view(ModelWithEnumsView, "ModelWithEnums") 
self.appbuilder.add_view(PSView, "Generic DS PS View", category="PSView") role_admin = self.appbuilder.sm.find_role("Admin") self.appbuilder.sm.add_user( "admin", "admin", "user", "*****@*****.**", role_admin, "general" ) role_read_only = self.appbuilder.sm.find_role("ReadOnly") self.appbuilder.sm.add_user( USERNAME_READONLY, "readonly", "readonly", "*****@*****.**", role_read_only, PASSWORD_READONLY ) def tearDown(self): self.appbuilder = None self.app = None self.db = None log.debug("TEAR DOWN") """ --------------------------------- TEST HELPER FUNCTIONS --------------------------------- """ def insert_data(self): for x, i in zip(string.ascii_letters[:23], range(23)): model = Model1(field_string="%stest" % (x), field_integer=i) self.db.session.add(model) self.db.session.commit() def insert_data2(self): models1 = [ Model1(field_string="G1"), Model1(field_string="G2"), Model1(field_string="G3"), ] for model1 in models1: try: self.db.session.add(model1) self.db.session.commit() for x, i in zip(string.ascii_letters[:10], range(10)): model = Model2( field_string="%stest" % (x), field_integer=random.randint(1, 10), field_float=random.uniform(0.0, 1.0), group=model1, ) year = random.choice(range(1900, 2012)) month = random.choice(range(1, 12)) day = random.choice(range(1, 28)) model.field_date = datetime.datetime(year, month, day) self.db.session.add(model) self.db.session.commit() except Exception as e: print("ERROR {0}".format(str(e))) self.db.session.rollback() def insert_data3(self): model3 = Model3(pk1=3, pk2=datetime.datetime(2017, 3, 3), field_string="foo") try: self.db.session.add(model3) self.db.session.commit() except Exception as e: print("Error {0}".format(str(e))) self.db.session.rollback() def test_fab_views(self): """ Test views creation and registration """ eq_(len(self.appbuilder.baseviews), 34) def test_back(self): """ Test Back functionality """ with self.app.test_client() as c: self.browser_login(c, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) 
c.get("/model1view/list/?_flt_0_field_string=f") c.get("/model2view/list/") c.get("/back", follow_redirects=True) assert request.args["_flt_0_field_string"] == u"f" assert "/model1view/list/" == request.path def test_model_creation(self): """ Test Model creation """ from sqlalchemy.engine.reflection import Inspector engine = self.db.session.get_bind(mapper=None, clause=None) inspector = Inspector.from_engine(engine) # Check if tables exist ok_("model1" in inspector.get_table_names()) ok_("model2" in inspector.get_table_names()) ok_("model3" in inspector.get_table_names()) ok_("model_with_enums" in inspector.get_table_names()) def test_index(self): """ Test initial access and index message """ client = self.app.test_client() # Check for Welcome Message rv = client.get("/") data = rv.data.decode("utf-8") ok_(DEFAULT_INDEX_STRING in data) def test_sec_login(self): """ Test Security Login, Logout, invalid login, invalid access """ client = self.app.test_client() # Try to List and Redirect to Login rv = client.get("/model1view/list/") eq_(rv.status_code, 302) rv = client.get("/model2view/list/") eq_(rv.status_code, 302) # Login and list with admin self.browser_login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) rv = client.get("/model1view/list/") eq_(rv.status_code, 200) rv = client.get("/model2view/list/") eq_(rv.status_code, 200) # Logout and and try to list self.browser_logout(client) rv = client.get("/model1view/list/") eq_(rv.status_code, 302) rv = client.get("/model2view/list/") eq_(rv.status_code, 302) # Invalid Login rv = self.browser_login(client, DEFAULT_ADMIN_USER, "password") data = rv.data.decode("utf-8") ok_(INVALID_LOGIN_STRING in data) def test_auth_builtin_roles(self): """ Test Security builtin roles readonly """ self.insert_data() client = self.app.test_client() self.browser_login(client, USERNAME_READONLY, PASSWORD_READONLY) # Test unauthorized GET rv = client.get("/model1view/list/") eq_(rv.status_code, 200) # Test unauthorized EDIT rv = 
client.get("/model1view/show/1") eq_(rv.status_code, 200) rv = client.get("/model1view/edit/1") eq_(rv.status_code, 302) # Test unauthorized DELETE rv = client.get("/model1view/delete/1") eq_(rv.status_code, 302) def test_sec_reset_password(self): """ Test Security reset password """ client = self.app.test_client() # Try Reset My password rv = client.get("/users/action/resetmypassword/1", follow_redirects=True) data = rv.data.decode("utf-8") ok_(ACCESS_IS_DENIED in data) # Reset My password rv = self.browser_login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) rv = client.get("/users/action/resetmypassword/1", follow_redirects=True) data = rv.data.decode("utf-8") ok_("Reset Password Form" in data) rv = client.post( "/resetmypassword/form", data=dict(password="******", conf_password="******"), follow_redirects=True, ) eq_(rv.status_code, 200) self.browser_logout(client) self.browser_login(client, DEFAULT_ADMIN_USER, "password") rv = client.post( "/resetmypassword/form", data=dict( password=DEFAULT_ADMIN_PASSWORD, conf_password=DEFAULT_ADMIN_PASSWORD ), follow_redirects=True, ) eq_(rv.status_code, 200) # Reset Password Admin rv = client.get("/users/action/resetpasswords/1", follow_redirects=True) data = rv.data.decode("utf-8") ok_("Reset Password Form" in data) rv = client.post( "/resetmypassword/form", data=dict( password=DEFAULT_ADMIN_PASSWORD, conf_password=DEFAULT_ADMIN_PASSWORD ), follow_redirects=True, ) eq_(rv.status_code, 200) def test_generic_interface(self): """ Test Generic Interface for generic-alter datasource """ client = self.app.test_client() self.browser_login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) rv = client.get("/psview/list") rv.data.decode("utf-8") def test_model_crud(self): """ Test Model add, delete, edit """ client = self.app.test_client() rv = self.browser_login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) rv = client.post( "/model1view/add", data=dict( field_string="test1", field_integer="1", field_float="0.12", 
field_date="2014-01-01", ), follow_redirects=True, ) eq_(rv.status_code, 200) model = self.db.session.query(Model1).first() eq_(model.field_string, u"test1") eq_(model.field_integer, 1) rv = client.post( "/model1view/edit/1", data=dict(field_string="test2", field_integer="2"), follow_redirects=True, ) eq_(rv.status_code, 200) model = self.db.session.query(Model1).first() eq_(model.field_string, u"test2") eq_(model.field_integer, 2) rv = client.get("/model1view/delete/1", follow_redirects=True) eq_(rv.status_code, 200) model = self.db.session.query(Model1).first() eq_(model, None) def test_model_crud_composite_pk(self): """ Test Generic Interface for generic-alter datasource where model has composite primary keys """ try: from urllib import quote except Exception: from urllib.parse import quote client = self.app.test_client() rv = self.browser_login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) rv = client.post( "/model3view/add", data=dict(pk1="1", pk2="2017-01-01 00:00:00", field_string="foo"), follow_redirects=True, ) eq_(rv.status_code, 200) model = self.db.session.query(Model3).first() eq_(model.pk1, 1) eq_(model.pk2, datetime.datetime(2017, 1, 1)) eq_(model.field_string, u"foo") pk = '[1, {"_type": "datetime", "value": "2017-01-01T00:00:00.000000"}]' rv = client.get("/model3view/show/" + quote(pk), follow_redirects=True) eq_(rv.status_code, 200) rv = client.post( "/model3view/edit/" + quote(pk), data=dict(pk1="2", pk2="2017-02-02 00:00:00", field_string="bar"), follow_redirects=True, ) eq_(rv.status_code, 200) model = self.db.session.query(Model3).first() eq_(model.pk1, 2) eq_(model.pk2, datetime.datetime(2017, 2, 2)) eq_(model.field_string, u"bar") pk = '[2, {"_type": "datetime", "value": "2017-02-02T00:00:00.000000"}]' rv = client.get("/model3view/delete/" + quote(pk), follow_redirects=True) eq_(rv.status_code, 200) model = self.db.session.query(Model3).first() eq_(model, None) def test_model_crud_with_enum(self): """ Test Model add, delete, edit for 
Model with Enum Columns """ client = self.app.test_client() rv = self.browser_login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) data = {"enum1": u"e1", "enum2": "e1"} rv = client.post("/modelwithenumsview/add", data=data, follow_redirects=True) eq_(rv.status_code, 200) model = self.db.session.query(ModelWithEnums).first() eq_(model.enum1, u"e1") eq_(model.enum2, TmpEnum.e1) data = {"enum1": u"e2", "enum2": "e2"} rv = client.post("/modelwithenumsview/edit/1", data=data, follow_redirects=True) eq_(rv.status_code, 200) model = self.db.session.query(ModelWithEnums).first() eq_(model.enum1, u"e2") eq_(model.enum2, TmpEnum.e2) rv = client.get("/modelwithenumsview/delete/1", follow_redirects=True) eq_(rv.status_code, 200) model = self.db.session.query(ModelWithEnums).first() eq_(model, None) def test_formatted_cols(self): """ Test ModelView's formatters_columns """ client = self.app.test_client() rv = self.browser_login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) self.insert_data() rv = client.get("/model1formattedview/list/") eq_(rv.status_code, 200) data = rv.data.decode("utf-8") ok_("FORMATTED_STRING" in data) rv = client.get("/model1formattedview/show/1") eq_(rv.status_code, 200) data = rv.data.decode("utf-8") ok_("FORMATTED_STRING" in data) def test_model_redirects(self): """ Test Model redirects after add, delete, edit """ client = self.app.test_client() rv = self.browser_login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) model1 = Model1(field_string="Test Redirects") self.db.session.add(model1) model1.id = REDIRECT_OBJ_ID self.db.session.flush() rv = client.post( "/model1viewwithredirects/add", data=dict( field_string="test_redirect", field_integer="1", field_float="0.12", field_date="2014-01-01", ), follow_redirects=True, ) eq_(rv.status_code, 200) data = rv.data.decode("utf-8") ok_("Test Redirects" in data) model_id = ( self.db.session.query(Model1) .filter_by(field_string="test_redirect") .first() .id ) rv = client.post( 
"/model1viewwithredirects/edit/{0}".format(model_id), data=dict(field_string="test_redirect_2", field_integer="2"), follow_redirects=True, ) eq_(rv.status_code, 200) ok_("Test Redirects" in data) rv = client.get( "/model1viewwithredirects/delete/{0}".format(model_id), follow_redirects=True, ) eq_(rv.status_code, 200) ok_("Test Redirects" in data) def test_excluded_cols(self): """ Test add_exclude_columns, edit_exclude_columns, show_exclude_columns """ client = self.app.test_client() rv = self.browser_login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) rv = client.get("/model22view/add") eq_(rv.status_code, 200) data = rv.data.decode("utf-8") ok_("field_string" in data) ok_("field_integer" in data) ok_("field_float" in data) ok_("field_date" in data) ok_("excluded_string" not in data) self.insert_data2() rv = client.get("/model22view/edit/1") eq_(rv.status_code, 200) data = rv.data.decode("utf-8") ok_("field_string" in data) ok_("field_integer" in data) ok_("field_float" in data) ok_("field_date" in data) ok_("excluded_string" not in data) rv = client.get("/model22view/show/1") eq_(rv.status_code, 200) data = rv.data.decode("utf-8") ok_("Field String" in data) ok_("Field Integer" in data) ok_("Field Float" in data) ok_("Field Date" in data) ok_("Excluded String" not in data) def test_query_rel_fields(self): """ Test add and edit form related fields filter """ client = self.app.test_client() rv = self.browser_login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) self.insert_data2() # Base filter string starts with rv = client.get("/model2view/add") data = rv.data.decode("utf-8") ok_("G1" in data) ok_("G2" not in data) # Base filter string starts with rv = client.get("/model2view/edit/1") data = rv.data.decode("utf-8") ok_("G2" in data) ok_("G1" not in data) def test_model_list_order(self): """ Test Model order on lists """ self.insert_data() client = self.app.test_client() self.browser_login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) rv = 
client.post( "/model1view/list?_oc_Model1View=field_string&_od_Model1View=asc", follow_redirects=True, ) # TODO: Fix this 405 error # eq_(rv.status_code, 200) rv.data.decode("utf-8") # TODO # VALIDATE LIST IS ORDERED rv = client.post( "/model1view/list?_oc_Model1View=field_string&_od_Model1View=desc", follow_redirects=True, ) # TODO: Fix this 405 error # eq_(rv.status_code, 200) rv.data.decode("utf-8") # TODO # VALIDATE LIST IS ORDERED def test_model_add_validation(self): """ Test Model add validations """ client = self.app.test_client() self.browser_login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) rv = client.post( "/model1view/add", data=dict(field_string="test1", field_integer="1"), follow_redirects=True, ) eq_(rv.status_code, 200) rv = client.post( "/model1view/add", data=dict(field_string="test1", field_integer="2"), follow_redirects=True, ) eq_(rv.status_code, 200) data = rv.data.decode("utf-8") ok_(UNIQUE_VALIDATION_STRING in data) model = self.db.session.query(Model1).all() eq_(len(model), 1) rv = client.post( "/model1view/add", data=dict(field_string="", field_integer="1"), follow_redirects=True, ) eq_(rv.status_code, 200) data = rv.data.decode("utf-8") ok_(NOTNULL_VALIDATION_STRING in data) model = self.db.session.query(Model1).all() eq_(len(model), 1) def test_model_edit_validation(self): """ Test Model edit validations """ client = self.app.test_client() self.browser_login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) client.post( "/model1view/add", data=dict(field_string="test1", field_integer="1"), follow_redirects=True, ) client.post( "/model1view/add", data=dict(field_string="test2", field_integer="1"), follow_redirects=True, ) rv = client.post( "/model1view/edit/1", data=dict(field_string="test2", field_integer="2"), follow_redirects=True, ) eq_(rv.status_code, 200) data = rv.data.decode("utf-8") ok_(UNIQUE_VALIDATION_STRING in data) rv = client.post( "/model1view/edit/1", data=dict(field_string="", field_integer="2"), 
follow_redirects=True, ) eq_(rv.status_code, 200) data = rv.data.decode("utf-8") ok_(NOTNULL_VALIDATION_STRING in data) def test_model_base_filter(self): """ Test Model base filtered views """ client = self.app.test_client() self.browser_login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) self.insert_data() models = self.db.session.query(Model1).all() eq_(len(models), 23) # Base filter string starts with rv = client.get("/model1filtered1view/list/") data = rv.data.decode("utf-8") ok_("atest" in data) ok_("btest" not in data) # Base filter integer equals rv = client.get("/model1filtered2view/list/") data = rv.data.decode("utf-8") ok_("atest" in data) ok_("btest" not in data) def test_model_list_method_field(self): """ Tests a model's field has a method """ client = self.app.test_client() self.browser_login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) self.insert_data2() rv = client.get("/model2view/list/") eq_(rv.status_code, 200) data = rv.data.decode("utf-8") ok_("field_method_value" in data) def test_compactCRUDMixin(self): """ Test CompactCRUD Mixin view """ client = self.app.test_client() self.browser_login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) self.insert_data2() rv = client.get("/model1compactview/list/") eq_(rv.status_code, 200) # test with composite pk try: from urllib import quote except Exception: from urllib.parse import quote self.insert_data3() pk = '[3, {"_type": "datetime", "value": "2017-03-03T00:00:00"}]' rv = client.post( "/model3compactview/edit/" + quote(pk), data=dict(field_string="bar"), follow_redirects=True, ) eq_(rv.status_code, 200) model = self.db.session.query(Model3).first() eq_(model.field_string, u"bar") rv = client.get("/model3compactview/delete/" + quote(pk), follow_redirects=True) eq_(rv.status_code, 200) model = self.db.session.query(Model3).first() eq_(model, None) def test_edit_add_form_action_prefix_for_compactCRUDMixin(self): """ Test form_action in add, form_action in edit (CompactCRUDMixin) """ 
client = self.app.test_client() self.browser_login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) # Make sure we have something to edit. self.insert_data() prefix = "/some-prefix" base_url = "http://localhost" + prefix session_form_action_key = "Model1CompactView__session_form_action" with client as c: expected_form_action = prefix + "/model1compactview/add/?" c.get("/model1compactview/add/", base_url=base_url) ok_(session[session_form_action_key] == expected_form_action) expected_form_action = prefix + "/model1compactview/edit/1?" c.get("/model1compactview/edit/1", base_url=base_url) ok_(session[session_form_action_key] == expected_form_action) def test_charts_view(self): """ Test Various Chart views """ client = self.app.test_client() self.browser_login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) self.insert_data2() log.info("CHART TEST") rv = client.get("/model2chartview/chart/") eq_(rv.status_code, 200) rv = client.get("/model2groupbychartview/chart/") eq_(rv.status_code, 200) rv = client.get("/model2directbychartview/chart/") eq_(rv.status_code, 200) rv = client.get("/model2timechartview/chart/") eq_(rv.status_code, 200) # TODO: fix this # rv = client.get('/model2directchartview/chart/') # eq_(rv.status_code, 200) def test_master_detail_view(self): """ Test Master detail view """ client = self.app.test_client() self.browser_login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) self.insert_data2() rv = client.get("/model1masterview/list/") eq_(rv.status_code, 200) rv = client.get("/model1masterview/list/1") eq_(rv.status_code, 200) rv = client.get("/model1masterchartview/list/") eq_(rv.status_code, 200) rv = client.get("/model1masterchartview/list/1") eq_(rv.status_code, 200) def test_api_read(self): """ Testing the api/read endpoint """ client = self.app.test_client() self.browser_login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) self.insert_data() rv = client.get("/model1formattedview/api/read") eq_(rv.status_code, 200) data = 
json.loads(rv.data.decode("utf-8")) assert "result" in data assert "pks" in data assert len(data.get("result")) > 10 def test_api_create(self): """ Testing the api/create endpoint """ client = self.app.test_client() self.browser_login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) rv = client.post( "/model1view/api/create", data=dict(field_string="zzz"), follow_redirects=True, ) eq_(rv.status_code, 200) objs = self.db.session.query(Model1).all() eq_(len(objs), 1) def test_api_update(self): """ Validate that the api update endpoint updates [only] the fields in POST data """ client = self.app.test_client() self.browser_login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) self.insert_data() item = self.db.session.query(Model1).filter_by(id=1).one() field_integer_before = item.field_integer rv = client.put( "/model1view/api/update/1", data=dict(field_string="zzz"), follow_redirects=True, ) eq_(rv.status_code, 200) item = self.db.session.query(Model1).filter_by(id=1).one() eq_(item.field_string, "zzz") eq_(item.field_integer, field_integer_before) def test_class_method_permission_override(self): """ MVC: Test class method permission name override """ from flask_appbuilder import ModelView from flask_appbuilder.models.sqla.interface import SQLAInterface class Model1PermOverride(ModelView): datamodel = SQLAInterface(Model1) class_permission_name = 'view' method_permission_name = { "list": "access", "show": "access", "edit": "access", "add": "access", "delete": "access", "download": "access", "api_readvalues": "access", "api_column_edit": "access", "api_column_add": "access", "api_delete": "access", "api_update": "access", "api_create": "access", "api_get": "access", "api_read": "access", "api": "access" } self.model1permoverride = Model1PermOverride self.appbuilder.add_view_no_menu(Model1PermOverride) role = self.appbuilder.sm.add_role("Test") pvm = self.appbuilder.sm.find_permission_view_menu( "can_access", "view" ) self.appbuilder.sm.add_permission_role(role, 
pvm) self.appbuilder.sm.add_user( "test", "test", "user", "*****@*****.**", role, "test" ) client = self.app.test_client() self.browser_login(client, "test", "test") rv = client.get("/model1permoverride/list/") eq_(rv.status_code, 200) rv = client.post( "/model1permoverride/add", data=dict( field_string="test1", field_integer="1", field_float="0.12", field_date="2014-01-01", ), follow_redirects=True, ) eq_(rv.status_code, 200) model = self.db.session.query(Model1).first() eq_(model.field_string, u"test1") eq_(model.field_integer, 1) def test_method_permission_override(self): """ MVC: Test method permission name override """ from flask_appbuilder import ModelView from flask_appbuilder.models.sqla.interface import SQLAInterface class Model1PermOverride(ModelView): datamodel = SQLAInterface(Model1) method_permission_name = { "list": "read", "show": "read", "edit": "write", "add": "write", "delete": "write", "download": "read", "api_readvalues": "read", "api_column_edit": "write", "api_column_add": "write", "api_delete": "write", "api_update": "write", "api_create": "write", "api_get": "read", "api_read": "read", "api": "read" } self.model1permoverride = Model1PermOverride self.appbuilder.add_view_no_menu(Model1PermOverride) role = self.appbuilder.sm.add_role("Test") pvm_read = self.appbuilder.sm.find_permission_view_menu( "can_read", "Model1PermOverride" ) pvm_write = self.appbuilder.sm.find_permission_view_menu( "can_write", "Model1PermOverride" ) self.appbuilder.sm.add_permission_role(role, pvm_read) self.appbuilder.sm.add_permission_role(role, pvm_write) self.appbuilder.sm.add_user( "test", "test", "user", "*****@*****.**", role, "test" ) client = self.app.test_client() self.browser_login(client, "test", "test") rv = client.post( "/model1permoverride/add", data=dict( field_string="test1", field_integer="1", field_float="0.12", field_date="2014-01-01", ), follow_redirects=True, ) eq_(rv.status_code, 200) model = self.db.session.query(Model1).first() 
eq_(model.field_string, u"test1") eq_(model.field_integer, 1) # Verify write links are on the UI rv = client.get("/model1permoverride/list/") eq_(rv.status_code, 200) data = rv.data.decode("utf-8") ok_("/model1permoverride/delete/1" in data) ok_("/model1permoverride/add" in data) ok_("/model1permoverride/edit/1" in data) ok_("/model1permoverride/show/1" in data) # Delete write permission from Test Role role = self.appbuilder.sm.find_role('Test') pvm_write = self.appbuilder.sm.find_permission_view_menu( "can_write", "Model1PermOverride" ) self.appbuilder.sm.del_permission_role(role, pvm_write) # Unauthorized delete rv = client.get("/model1permoverride/delete/1") eq_(rv.status_code, 302) model = self.db.session.query(Model1).first() eq_(model.field_string, u"test1") eq_(model.field_integer, 1) # Verify write links are gone from UI rv = client.get("/model1permoverride/list/") eq_(rv.status_code, 200) data = rv.data.decode("utf-8") ok_("/model1permoverride/delete/1" not in data) ok_("/model1permoverride/add/" not in data) ok_("/model1permoverride/edit/1" not in data) ok_("/model1permoverride/show/1" in data) def test_action_permission_override(self): """ MVC: Test action permission name override """ from flask_appbuilder import action, ModelView from flask_appbuilder.models.sqla.interface import SQLAInterface class Model1PermOverride(ModelView): datamodel = SQLAInterface(Model1) method_permission_name = { "list": "read", "show": "read", "edit": "write", "add": "write", "delete": "write", "download": "read", "api_readvalues": "read", "api_column_edit": "write", "api_column_add": "write", "api_delete": "write", "api_update": "write", "api_create": "write", "api_get": "read", "api_read": "read", "api": "read", "action_one": "write" } @action("action1", "Action1", "", "fa-lock", multiple=True) def action_one(self, item): return "ACTION ONE" self.model1permoverride = Model1PermOverride self.appbuilder.add_view_no_menu(Model1PermOverride) # Add a user and login before 
enabling CSRF role = self.appbuilder.sm.add_role("Test") self.appbuilder.sm.add_user( "test", "test", "user", "*****@*****.**", role, "test" ) pvm_read = self.appbuilder.sm.find_permission_view_menu( "can_read", "Model1PermOverride" ) pvm_write = self.appbuilder.sm.find_permission_view_menu( "can_write", "Model1PermOverride" ) self.appbuilder.sm.add_permission_role(role, pvm_read) self.appbuilder.sm.add_permission_role(role, pvm_write) client = self.app.test_client() self.browser_login(client, "test", "test") rv = client.get("/model1permoverride/action/action1/1") eq_(rv.status_code, 200) # Delete write permission from Test Role role = self.appbuilder.sm.find_role('Test') pvm_write = self.appbuilder.sm.find_permission_view_menu( "can_write", "Model1PermOverride" ) self.appbuilder.sm.del_permission_role(role, pvm_write) rv = client.get("/model1permoverride/action/action1/1") eq_(rv.status_code, 302) def test_permission_converge_compress(self): """ MVC: Test permission name converge compress """ from flask_appbuilder import ModelView from flask_appbuilder.models.sqla.interface import SQLAInterface class Model1PermConverge(ModelView): datamodel = SQLAInterface(Model1) class_permission_name = 'view2' previous_class_permission_name = 'Model1View' method_permission_name = { "list": "access", "show": "access", "edit": "access", "add": "access", "delete": "access", "download": "access", "api_readvalues": "access", "api_column_edit": "access", "api_column_add": "access", "api_delete": "access", "api_update": "access", "api_create": "access", "api_get": "access", "api_read": "access", "api": "access" } self.appbuilder.add_view_no_menu(Model1PermConverge) role = self.appbuilder.sm.add_role("Test") pvm = self.appbuilder.sm.find_permission_view_menu( "can_list", "Model1View" ) self.appbuilder.sm.add_permission_role(role, pvm) pvm = self.appbuilder.sm.find_permission_view_menu( "can_add", "Model1View" ) self.appbuilder.sm.add_permission_role(role, pvm) role = 
self.appbuilder.sm.find_role("Test") self.appbuilder.sm.add_user( "test", "test", "user", "*****@*****.**", role, "test" ) # Remove previous class, Hack to test code change for i, baseview in enumerate(self.appbuilder.baseviews): if baseview.__class__.__name__ == "Model1View": break self.appbuilder.baseviews.pop(i) target_state_transitions = { 'add': { ('Model1View', 'can_edit'): {('view2', 'can_access')}, ('Model1View', 'can_add'): {('view2', 'can_access')}, ('Model1View', 'can_list'): {('view2', 'can_access')}, ('Model1View', 'can_download'): {('view2', 'can_access')}, ('Model1View', 'can_show'): {('view2', 'can_access')}, ('Model1View', 'can_delete'): {('view2', 'can_access')} }, 'del_role_pvm': { ('Model1View', 'can_show'), ('Model1View', 'can_add'), ('Model1View', 'can_download'), ('Model1View', 'can_list'), ('Model1View', 'can_edit'), ('Model1View', 'can_delete') }, 'del_views': { 'Model1View' }, 'del_perms': set() } state_transitions = self.appbuilder.security_converge() eq_(state_transitions, target_state_transitions) role = self.appbuilder.sm.find_role("Test") eq_(len(role.permissions), 1)
class FlaskTestCase(unittest.TestCase):
    """Functional tests driving Flask-AppBuilder views through the test client.

    Every test gets a fresh Flask app, database handle and AppBuilder (see
    ``setUp``) with one view registered per feature under test, plus an
    ``admin`` user holding the ``Admin`` role.
    """

    def setUp(self):
        """Create the app, the database handle and register all test views."""
        from flask import Flask
        from flask_appbuilder import AppBuilder
        from flask_appbuilder.models.sqla.interface import SQLAInterface
        from flask_appbuilder.views import ModelView

        self.app = Flask(__name__)
        self.basedir = os.path.abspath(os.path.dirname(__file__))
        self.app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///'
        self.app.config['CSRF_ENABLED'] = False
        self.app.config['SECRET_KEY'] = 'thisismyscretkey'
        # CSRF is disabled so the tests can POST forms without a token.
        self.app.config['WTF_CSRF_ENABLED'] = False

        self.db = SQLA(self.app)
        self.appbuilder = AppBuilder(self.app, self.db.session)

        sess = PSSession()

        class PSView(ModelView):
            datamodel = GenericInterface(PSModel, sess)
            base_permissions = ['can_list', 'can_show']
            list_columns = ['UID', 'C', 'CMD', 'TIME']
            search_columns = ['UID', 'C', 'CMD']

        class Model2View(ModelView):
            datamodel = SQLAInterface(Model2)
            list_columns = [
                'field_integer', 'field_float', 'field_string',
                'field_method', 'group.field_string'
            ]
            # The edit form only offers group 'G2', the add form only 'G1';
            # test_query_rel_fields relies on this asymmetry.
            edit_form_query_rel_fields = {
                'group': [['field_string', FilterEqual, 'G2']]
            }
            add_form_query_rel_fields = {
                'group': [['field_string', FilterEqual, 'G1']]
            }

        class Model22View(ModelView):
            datamodel = SQLAInterface(Model2)
            list_columns = [
                'field_integer', 'field_float', 'field_string',
                'field_method', 'group.field_string'
            ]
            add_exclude_columns = ['excluded_string']
            edit_exclude_columns = ['excluded_string']
            show_exclude_columns = ['excluded_string']

        class Model1View(ModelView):
            datamodel = SQLAInterface(Model1)
            related_views = [Model2View]
            list_columns = ['field_string', 'field_file']

        class Model1CompactView(CompactCRUDMixin, ModelView):
            datamodel = SQLAInterface(Model1)

        class Model1ViewWithRedirects(ModelView):
            datamodel = SQLAInterface(Model1)
            obj_id = 1

            # All three hooks send the user to the show page of the
            # REDIRECT_OBJ_ID record, whatever record was just modified.
            def post_add_redirect(self):
                return redirect(
                    'model1viewwithredirects/show/{0}'.format(REDIRECT_OBJ_ID))

            def post_edit_redirect(self):
                return redirect(
                    'model1viewwithredirects/show/{0}'.format(REDIRECT_OBJ_ID))

            def post_delete_redirect(self):
                return redirect(
                    'model1viewwithredirects/show/{0}'.format(REDIRECT_OBJ_ID))

        class Model1Filtered1View(ModelView):
            datamodel = SQLAInterface(Model1)
            base_filters = [['field_string', FilterStartsWith, 'a']]

        class Model1MasterView(MasterDetailView):
            datamodel = SQLAInterface(Model1)
            related_views = [Model2View]

        class Model1Filtered2View(ModelView):
            datamodel = SQLAInterface(Model1)
            base_filters = [['field_integer', FilterEqual, 0]]

        class Model2ChartView(ChartView):
            datamodel = SQLAInterface(Model2)
            chart_title = 'Test Model1 Chart'
            group_by_columns = ['field_string']

        class Model2GroupByChartView(GroupByChartView):
            datamodel = SQLAInterface(Model2)
            chart_title = 'Test Model1 Chart'
            definitions = [{
                'group': 'field_string',
                'series': [(aggregate_sum, 'field_integer',
                            aggregate_avg, 'field_integer',
                            aggregate_count, 'field_integer')]
            }]

        class Model2DirectByChartView(DirectByChartView):
            datamodel = SQLAInterface(Model2)
            chart_title = 'Test Model1 Chart'
            definitions = [{
                'group': 'field_string',
                'series': ['field_integer', 'field_float']
            }]

        class Model2TimeChartView(TimeChartView):
            datamodel = SQLAInterface(Model2)
            chart_title = 'Test Model1 Chart'
            group_by_columns = ['field_date']

        class Model2DirectChartView(DirectChartView):
            datamodel = SQLAInterface(Model2)
            chart_title = 'Test Model1 Chart'
            direct_columns = {'stat1': ('group', 'field_integer')}

        class Model1MasterChartView(MasterDetailView):
            datamodel = SQLAInterface(Model1)
            related_views = [Model2DirectByChartView]

        class Model1FormattedView(ModelView):
            datamodel = SQLAInterface(Model1)
            list_columns = ['field_string']
            show_columns = ['field_string']
            formatters_columns = {
                'field_string': lambda x: 'FORMATTED_STRING',
            }

        self.appbuilder.add_view(Model1View, "Model1", category='Model1')
        self.appbuilder.add_view(Model1ViewWithRedirects,
                                 "Model1ViewWithRedirects", category='Model1')
        self.appbuilder.add_view(Model1CompactView, "Model1Compact",
                                 category='Model1')
        self.appbuilder.add_view(Model1MasterView, "Model1Master",
                                 category='Model1')
        self.appbuilder.add_view(Model1MasterChartView, "Model1MasterChart",
                                 category='Model1')
        self.appbuilder.add_view(Model1Filtered1View, "Model1Filtered1",
                                 category='Model1')
        self.appbuilder.add_view(Model1Filtered2View, "Model1Filtered2",
                                 category='Model1')
        self.appbuilder.add_view(Model1FormattedView, "Model1FormattedView",
                                 category='Model1FormattedView')

        self.appbuilder.add_view(Model2View, "Model2")
        self.appbuilder.add_view(Model22View, "Model22")
        self.appbuilder.add_view(Model2View, "Model2 Add",
                                 href='/model2view/add')
        self.appbuilder.add_view(Model2ChartView, "Model2 Chart")
        self.appbuilder.add_view(Model2GroupByChartView,
                                 "Model2 Group By Chart")
        self.appbuilder.add_view(Model2DirectByChartView,
                                 "Model2 Direct By Chart")
        self.appbuilder.add_view(Model2TimeChartView, "Model2 Time Chart")
        self.appbuilder.add_view(Model2DirectChartView, "Model2 Direct Chart")

        self.appbuilder.add_view(PSView, "Generic DS PS View",
                                 category='PSView')

        role_admin = self.appbuilder.sm.find_role('Admin')
        self.appbuilder.sm.add_user('admin', 'admin', 'user',
                                    '*****@*****.**', role_admin, 'general')

    def tearDown(self):
        """Drop the references to the per-test app, builder and database."""
        self.appbuilder = None
        self.app = None
        self.db = None
        log.debug("TEAR DOWN")

    # ---------------------------------
    # TEST HELPER FUNCTIONS
    # ---------------------------------

    def login(self, client, username, password):
        """POST the credentials to ``/login/`` and return the final response."""
        # Login with default admin
        return client.post('/login/',
                           data=dict(username=username, password=password),
                           follow_redirects=True)

    def logout(self, client):
        """GET ``/logout/`` and return the response (redirects not followed)."""
        return client.get('/logout/')

    def insert_data(self):
        """Insert 23 ``Model1`` rows: 'atest'/0, 'btest'/1, ... one per letter."""
        for i, x in enumerate(string.ascii_letters[:23]):
            model = Model1(field_string="%stest" % (x), field_integer=i)
            self.db.session.add(model)
            self.db.session.commit()

    def insert_data2(self):
        """Insert groups G1..G3 (``Model1``), each with ten random ``Model2`` rows.

        A failure while inserting one group is printed and rolled back; the
        remaining groups are still attempted.
        """
        models1 = [
            Model1(field_string='G1'),
            Model1(field_string='G2'),
            Model1(field_string='G3')
        ]
        for model1 in models1:
            try:
                self.db.session.add(model1)
                self.db.session.commit()
                for x in string.ascii_letters[:10]:
                    model = Model2(field_string="%stest" % (x),
                                   field_integer=random.randint(1, 10),
                                   field_float=random.uniform(0.0, 1.0),
                                   group=model1)
                    # Ranges stop at month 11 / day 27, so any combination
                    # is a valid calendar date.
                    year = random.choice(range(1900, 2012))
                    month = random.choice(range(1, 12))
                    day = random.choice(range(1, 28))
                    model.field_date = datetime.datetime(year, month, day)
                    self.db.session.add(model)
                    self.db.session.commit()
            except Exception as e:
                print("ERROR {0}".format(str(e)))
                self.db.session.rollback()

    def test_fab_views(self):
        """
            Test views creation and registration
        """
        eq_(len(self.appbuilder.baseviews), 29)  # current minimal views are 12

    def test_back(self):
        """
            Test Back functionality
        """
        with self.app.test_client() as c:
            self.login(c, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD)
            c.get('/model1view/list/?_flt_0_field_string=f')
            c.get('/model2view/list/')
            c.get('/back', follow_redirects=True)
            # ``request`` is inspected inside the ``with`` block, after the
            # final redirected request.
            assert request.args['_flt_0_field_string'] == u'f'
            assert '/model1view/list/' == request.path

    def test_model_creation(self):
        """
            Test Model creation
        """
        from sqlalchemy.engine.reflection import Inspector

        engine = self.db.session.get_bind(mapper=None, clause=None)
        inspector = Inspector.from_engine(engine)
        # Check if tables exist
        ok_('model1' in inspector.get_table_names())
        ok_('model2' in inspector.get_table_names())

    def test_index(self):
        """
            Test initial access and index message
        """
        client = self.app.test_client()

        # Check for Welcome Message
        rv = client.get('/')
        data = rv.data.decode('utf-8')
        ok_(DEFAULT_INDEX_STRING in data)

    def test_sec_login(self):
        """
            Test Security Login, Logout, invalid login, invalid access
        """
        client = self.app.test_client()

        # Try to List and Redirect to Login
        rv = client.get('/model1view/list/')
        eq_(rv.status_code, 302)
        rv = client.get('/model2view/list/')
        eq_(rv.status_code, 302)

        # Login and list with admin
        self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD)
        rv = client.get('/model1view/list/')
        eq_(rv.status_code, 200)
        rv = client.get('/model2view/list/')
        eq_(rv.status_code, 200)

        # Logout and and try to list
        self.logout(client)
        rv = client.get('/model1view/list/')
        eq_(rv.status_code, 302)
        rv = client.get('/model2view/list/')
        eq_(rv.status_code, 302)

        # Invalid Login
        rv = self.login(client, DEFAULT_ADMIN_USER, 'password')
        data = rv.data.decode('utf-8')
        ok_(INVALID_LOGIN_STRING in data)

    def test_sec_reset_password(self):
        """
            Test Security reset password
        """
        client = self.app.test_client()

        # Try Reset My password
        rv = client.get('/users/action/resetmypassword/1',
                        follow_redirects=True)
        data = rv.data.decode('utf-8')
        ok_(ACCESS_IS_DENIED in data)

        # Reset My password
        rv = self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD)
        rv = client.get('/users/action/resetmypassword/1',
                        follow_redirects=True)
        data = rv.data.decode('utf-8')
        ok_("Reset Password Form" in data)
        rv = client.post('/resetmypassword/form',
                         data=dict(password='******',
                                   conf_password='******'),
                         follow_redirects=True)
        eq_(rv.status_code, 200)
        self.logout(client)
        self.login(client, DEFAULT_ADMIN_USER, 'password')
        rv = client.post('/resetmypassword/form',
                         data=dict(password=DEFAULT_ADMIN_PASSWORD,
                                   conf_password=DEFAULT_ADMIN_PASSWORD),
                         follow_redirects=True)
        eq_(rv.status_code, 200)

        # Reset Password Admin
        rv = client.get('/users/action/resetpasswords/1',
                        follow_redirects=True)
        data = rv.data.decode('utf-8')
        ok_("Reset Password Form" in data)
        rv = client.post('/resetmypassword/form',
                         data=dict(password=DEFAULT_ADMIN_PASSWORD,
                                   conf_password=DEFAULT_ADMIN_PASSWORD),
                         follow_redirects=True)
        eq_(rv.status_code, 200)

    def test_generic_interface(self):
        """
            Test Generic Interface for generic-alter datasource
        """
        client = self.app.test_client()
        rv = self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD)
        rv = client.get('/psview/list')
        # Only checks that the response body decodes; no content assertion.
        rv.data.decode('utf-8')

    def test_model_crud(self):
        """
            Test Model add, delete, edit
        """
        client = self.app.test_client()
        rv = self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD)

        rv = client.post('/model1view/add',
                         data=dict(field_string='test1', field_integer='1',
                                   field_float='0.12',
                                   field_date='2014-01-01'),
                         follow_redirects=True)
        eq_(rv.status_code, 200)

        model = self.db.session.query(Model1).first()
        eq_(model.field_string, u'test1')
        eq_(model.field_integer, 1)

        rv = client.post('/model1view/edit/1',
                         data=dict(field_string='test2', field_integer='2'),
                         follow_redirects=True)
        eq_(rv.status_code, 200)

        model = self.db.session.query(Model1).first()
        eq_(model.field_string, u'test2')
        eq_(model.field_integer, 2)

        rv = client.get('/model1view/delete/1', follow_redirects=True)
        eq_(rv.status_code, 200)
        model = self.db.session.query(Model1).first()
        eq_(model, None)

    def test_formatted_cols(self):
        """
            Test ModelView's formatters_columns
        """
        client = self.app.test_client()
        rv = self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD)
        self.insert_data()
        rv = client.get('/model1formattedview/list/')
        eq_(rv.status_code, 200)
        data = rv.data.decode('utf-8')
        ok_('FORMATTED_STRING' in data)
        rv = client.get('/model1formattedview/show/1')
        eq_(rv.status_code, 200)
        data = rv.data.decode('utf-8')
        ok_('FORMATTED_STRING' in data)

    def test_model_redirects(self):
        """
            Test Model redirects after add, delete, edit
        """
        client = self.app.test_client()
        rv = self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD)

        # Target record of every post_*_redirect hook.
        model1 = Model1(field_string='Test Redirects')
        self.db.session.add(model1)
        model1.id = REDIRECT_OBJ_ID
        self.db.session.flush()

        rv = client.post('/model1viewwithredirects/add',
                         data=dict(field_string='test_redirect',
                                   field_integer='1',
                                   field_float='0.12',
                                   field_date='2014-01-01'),
                         follow_redirects=True)
        eq_(rv.status_code, 200)
        data = rv.data.decode('utf-8')
        ok_('Test Redirects' in data)

        model_id = self.db.session.query(Model1).filter_by(
            field_string='test_redirect').first().id
        rv = client.post('/model1viewwithredirects/edit/{0}'.format(model_id),
                         data=dict(field_string='test_redirect_2',
                                   field_integer='2'),
                         follow_redirects=True)
        eq_(rv.status_code, 200)
        # Decode THIS response; re-checking the body of the add response
        # would pass regardless of where the edit redirected.
        data = rv.data.decode('utf-8')
        ok_('Test Redirects' in data)

        rv = client.get('/model1viewwithredirects/delete/{0}'.format(model_id),
                        follow_redirects=True)
        eq_(rv.status_code, 200)
        data = rv.data.decode('utf-8')
        ok_('Test Redirects' in data)

    def test_excluded_cols(self):
        """
            Test add_exclude_columns, edit_exclude_columns, show_exclude_columns
        """
        client = self.app.test_client()
        rv = self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD)
        rv = client.get('/model22view/add')
        eq_(rv.status_code, 200)
        data = rv.data.decode('utf-8')
        ok_('field_string' in data)
        ok_('field_integer' in data)
        ok_('field_float' in data)
        ok_('field_date' in data)
        ok_('excluded_string' not in data)
        self.insert_data2()
        rv = client.get('/model22view/edit/1')
        eq_(rv.status_code, 200)
        data = rv.data.decode('utf-8')
        ok_('field_string' in data)
        ok_('field_integer' in data)
        ok_('field_float' in data)
        ok_('field_date' in data)
        ok_('excluded_string' not in data)
        rv = client.get('/model22view/show/1')
        eq_(rv.status_code, 200)
        data = rv.data.decode('utf-8')
        ok_('Field String' in data)
        ok_('Field Integer' in data)
        ok_('Field Float' in data)
        ok_('Field Date' in data)
        ok_('Excluded String' not in data)

    def test_query_rel_fields(self):
        """
            Test add and edit form related fields filter
        """
        client = self.app.test_client()
        rv = self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD)
        self.insert_data2()

        # Add form: related field filtered to group 'G1'
        rv = client.get('/model2view/add')
        data = rv.data.decode('utf-8')
        ok_('G1' in data)
        ok_('G2' not in data)

        # Edit form: related field filtered to group 'G2'
        rv = client.get('/model2view/edit/1')
        data = rv.data.decode('utf-8')
        ok_('G2' in data)
        ok_('G1' not in data)

    def test_model_list_order(self):
        """
            Test Model order on lists
        """
        self.insert_data()

        client = self.app.test_client()
        self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD)

        rv = client.post(
            '/model1view/list?_oc_Model1View=field_string&_od_Model1View=asc',
            follow_redirects=True)
        # TODO: Fix this 405 error
        # eq_(rv.status_code, 200)
        rv.data.decode('utf-8')
        # TODO
        # VALIDATE LIST IS ORDERED
        rv = client.post(
            '/model1view/list?_oc_Model1View=field_string&_od_Model1View=desc',
            follow_redirects=True)
        # TODO: Fix this 405 error
        # eq_(rv.status_code, 200)
        rv.data.decode('utf-8')
        # TODO
        # VALIDATE LIST IS ORDERED

    def test_model_add_validation(self):
        """
            Test Model add validations
        """
        client = self.app.test_client()
        self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD)

        rv = client.post('/model1view/add',
                         data=dict(field_string='test1', field_integer='1'),
                         follow_redirects=True)
        eq_(rv.status_code, 200)

        # Same field_string again: rejected by the unique validation.
        rv = client.post('/model1view/add',
                         data=dict(field_string='test1', field_integer='2'),
                         follow_redirects=True)
        eq_(rv.status_code, 200)
        data = rv.data.decode('utf-8')
        ok_(UNIQUE_VALIDATION_STRING in data)

        model = self.db.session.query(Model1).all()
        eq_(len(model), 1)

        # Empty field_string: rejected by the not-null validation.
        rv = client.post('/model1view/add',
                         data=dict(field_string='', field_integer='1'),
                         follow_redirects=True)
        eq_(rv.status_code, 200)
        data = rv.data.decode('utf-8')
        ok_(NOTNULL_VALIDATION_STRING in data)

        model = self.db.session.query(Model1).all()
        eq_(len(model), 1)

    def test_model_edit_validation(self):
        """
            Test Model edit validations
        """
        client = self.app.test_client()
        self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD)

        client.post('/model1view/add',
                    data=dict(field_string='test1', field_integer='1'),
                    follow_redirects=True)
        client.post('/model1view/add',
                    data=dict(field_string='test2', field_integer='1'),
                    follow_redirects=True)
        rv = client.post('/model1view/edit/1',
                         data=dict(field_string='test2', field_integer='2'),
                         follow_redirects=True)
        eq_(rv.status_code, 200)
        data = rv.data.decode('utf-8')
        ok_(UNIQUE_VALIDATION_STRING in data)

        rv = client.post('/model1view/edit/1',
                         data=dict(field_string='', field_integer='2'),
                         follow_redirects=True)
        eq_(rv.status_code, 200)
        data = rv.data.decode('utf-8')
        ok_(NOTNULL_VALIDATION_STRING in data)

    def test_model_base_filter(self):
        """
            Test Model base filtered views
        """
        client = self.app.test_client()
        self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD)
        self.insert_data()
        models = self.db.session.query(Model1).all()
        eq_(len(models), 23)

        # Base filter string starts with
        rv = client.get('/model1filtered1view/list/')
        data = rv.data.decode('utf-8')
        ok_('atest' in data)
        ok_('btest' not in data)

        # Base filter integer equals
        rv = client.get('/model1filtered2view/list/')
        data = rv.data.decode('utf-8')
        ok_('atest' in data)
        ok_('btest' not in data)

    def test_model_list_method_field(self):
        """
            Tests a model's field has a method
        """
        client = self.app.test_client()
        self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD)
        self.insert_data2()
        rv = client.get('/model2view/list/')
        eq_(rv.status_code, 200)
        data = rv.data.decode('utf-8')
        ok_('field_method_value' in data)

    def test_compactCRUDMixin(self):
        """
            Test CompactCRUD Mixin view
        """
        client = self.app.test_client()
        self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD)
        self.insert_data2()
        rv = client.get('/model1compactview/list/')
        eq_(rv.status_code, 200)

    def test_charts_view(self):
        """
            Test Various Chart views
        """
        client = self.app.test_client()
        self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD)
        self.insert_data2()
        log.info("CHART TEST")
        rv = client.get('/model2chartview/chart/')
        eq_(rv.status_code, 200)
        rv = client.get('/model2groupbychartview/chart/')
        eq_(rv.status_code, 200)
        rv = client.get('/model2directbychartview/chart/')
        eq_(rv.status_code, 200)
        rv = client.get('/model2timechartview/chart/')
        eq_(rv.status_code, 200)
        # TODO: fix this
        # rv = client.get('/model2directchartview/chart/')
        # eq_(rv.status_code, 200)

    def test_master_detail_view(self):
        """
            Test Master detail view
        """
        client = self.app.test_client()
        self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD)
        self.insert_data2()
        rv = client.get('/model1masterview/list/')
        eq_(rv.status_code, 200)
        rv = client.get('/model1masterview/list/1')
        eq_(rv.status_code, 200)
        rv = client.get('/model1masterchartview/list/')
        eq_(rv.status_code, 200)
        rv = client.get('/model1masterchartview/list/1')
        eq_(rv.status_code, 200)
# Application bootstrap: configures logging, creates the Flask app and
# attaches the JSGlue, Bootstrap, SQLA and AppBuilder objects to it.
from flask_jsglue import JSGlue
from .index import MyIndexView
from flask_bootstrap import Bootstrap
from flask_ngrok import run_with_ngrok

# Root logger: "time:level:name:message" records, DEBUG and above.
logging.basicConfig(format="%(asctime)s:%(levelname)s:%(name)s:%(message)s")
logging.getLogger().setLevel(logging.DEBUG)

app = Flask(__name__)
jsglue = JSGlue(app)
bootstrap = Bootstrap(app)
# Settings are loaded from the importable object named "config".
app.config.from_object("config")
db = SQLA(app)
# AppBuilder is created with a non-reversed menu, the project's security
# manager class and the project's index view.
appbuilder = AppBuilder(
    app,
    db.session,
    menu=Menu(reverse=False),
    security_manager_class=MySecurityManager,
    indexview=MyIndexView
)

# The bare string below is never executed. It holds a disabled snippet that
# would register an Engine "connect" listener issuing
# "PRAGMA foreign_keys=ON" on each new DB-API connection.
""" from sqlalchemy.engine import Engine from sqlalchemy import event #Only include this for SQLLite constraints @event.listens_for(Engine, "connect") def set_sqlite_pragma(dbapi_connection, connection_record): # Will force sqllite contraint foreign keys cursor = dbapi_connection.cursor() cursor.execute("PRAGMA foreign_keys=ON") cursor.close() """

# NOTE(review): imported last, presumably because .models needs the objects
# created above to exist at import time -- confirm before moving this line.
from .models import *
from .security import SecurityManager # Import Flask appbuilder functions to create the appbuilder object from flask_appbuilder import SQLA, AppBuilder """ Logging configuration """ logging.basicConfig(format='%(asctime)s:%(levelname)s:%(name)s:%(message)s') logging.getLogger().setLevel(logging.DEBUG) # Create flask app object app = Flask(__name__) # Get Configs from config.py app.config.from_object('config') # Create Database object frmo flask app object db = SQLA(app) # Create Appbuilder object from db and flask app object and customized classes appbuilder = AppBuilder(app, db.session, indexview=MyIndexView, security_manager_class=SecurityManager) # Import views for running app from app import views # Uncomment code to fix table error # from app.models import CustomUser # CustomUser.add_column(db) # from app.models import Classes # Classes.add_column(db)
def setUp(self):
    """
        Build a fresh Flask app, database handle and AppBuilder for one test.

        The app is configured with the SQLite URI 'sqlite:///' (no file path)
        and with both CSRF switches turned off so the test client can post
        forms directly. All view classes used by the tests are declared
        locally, registered on the AppBuilder, and finally an 'admin' user
        holding the 'Admin' role is created.
    """
    from flask import Flask
    from flask_appbuilder import AppBuilder
    from flask_appbuilder.models.sqla.interface import SQLAInterface
    from flask_appbuilder.views import ModelView

    self.app = Flask(__name__)
    self.basedir = os.path.abspath(os.path.dirname(__file__))
    self.app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///'
    self.app.config['CSRF_ENABLED'] = False
    self.app.config['SECRET_KEY'] = 'thisismyscretkey'
    self.app.config['WTF_CSRF_ENABLED'] = False
    self.db = SQLA(self.app)
    self.appbuilder = AppBuilder(self.app, self.db.session)

    # Session for the generic (non-SQLAlchemy) datasource used by PSView.
    sess = PSSession()

    # ---- generic interface view ------------------------------------------
    class PSView(ModelView):
        datamodel = GenericInterface(PSModel, sess)
        base_permissions = ['can_list', 'can_show']
        list_columns = ['UID', 'C', 'CMD', 'TIME']
        search_columns = ['UID', 'C', 'CMD']

    # ---- Model2 views ----------------------------------------------------
    class Model2View(ModelView):
        datamodel = SQLAInterface(Model2)
        list_columns = ['field_integer', 'field_float', 'field_string',
                        'field_method', 'group.field_string']
        # The edit form only offers group 'G2', the add form only 'G1'.
        edit_form_query_rel_fields = {'group': [['field_string', FilterEqual, 'G2']]}
        add_form_query_rel_fields = {'group': [['field_string', FilterEqual, 'G1']]}

    class Model22View(ModelView):
        datamodel = SQLAInterface(Model2)
        list_columns = ['field_integer', 'field_float', 'field_string',
                        'field_method', 'group.field_string']
        add_exclude_columns = ['excluded_string']
        edit_exclude_columns = ['excluded_string']
        show_exclude_columns = ['excluded_string']

    # ---- Model1 views ----------------------------------------------------
    class Model1View(ModelView):
        datamodel = SQLAInterface(Model1)
        related_views = [Model2View]
        list_columns = ['field_string', 'field_file']

    class Model1CompactView(CompactCRUDMixin, ModelView):
        datamodel = SQLAInterface(Model1)

    class Model1ViewWithRedirects(ModelView):
        datamodel = SQLAInterface(Model1)
        obj_id = 1

        # All three hooks send the user to the show page of REDIRECT_OBJ_ID.
        def post_add_redirect(self):
            return redirect('model1viewwithredirects/show/{0}'.format(REDIRECT_OBJ_ID))

        def post_edit_redirect(self):
            return redirect('model1viewwithredirects/show/{0}'.format(REDIRECT_OBJ_ID))

        def post_delete_redirect(self):
            return redirect('model1viewwithredirects/show/{0}'.format(REDIRECT_OBJ_ID))

    class Model1Filtered1View(ModelView):
        datamodel = SQLAInterface(Model1)
        base_filters = [['field_string', FilterStartsWith, 'a']]

    # Declared exactly once; the original repeated this identical class
    # further down, which only shadowed this definition.
    class Model1MasterView(MasterDetailView):
        datamodel = SQLAInterface(Model1)
        related_views = [Model2View]

    class Model1Filtered2View(ModelView):
        datamodel = SQLAInterface(Model1)
        base_filters = [['field_integer', FilterEqual, 0]]

    # ---- chart views -----------------------------------------------------
    class Model2ChartView(ChartView):
        datamodel = SQLAInterface(Model2)
        chart_title = 'Test Model1 Chart'
        group_by_columns = ['field_string']

    class Model2GroupByChartView(GroupByChartView):
        datamodel = SQLAInterface(Model2)
        chart_title = 'Test Model1 Chart'
        definitions = [
            {
                'group': 'field_string',
                'series': [(aggregate_sum, 'field_integer',
                            aggregate_avg, 'field_integer',
                            aggregate_count, 'field_integer')],
            }
        ]

    class Model2DirectByChartView(DirectByChartView):
        datamodel = SQLAInterface(Model2)
        chart_title = 'Test Model1 Chart'
        definitions = [
            {
                'group': 'field_string',
                'series': ['field_integer', 'field_float'],
            }
        ]

    class Model2TimeChartView(TimeChartView):
        datamodel = SQLAInterface(Model2)
        chart_title = 'Test Model1 Chart'
        group_by_columns = ['field_date']

    class Model2DirectChartView(DirectChartView):
        datamodel = SQLAInterface(Model2)
        chart_title = 'Test Model1 Chart'
        direct_columns = {'stat1': ('group', 'field_integer')}

    class Model1MasterChartView(MasterDetailView):
        datamodel = SQLAInterface(Model1)
        related_views = [Model2DirectByChartView]

    # ---- misc views ------------------------------------------------------
    class Model1FormattedView(ModelView):
        datamodel = SQLAInterface(Model1)
        list_columns = ['field_string']
        show_columns = ['field_string']
        formatters_columns = {
            'field_string': lambda x: 'FORMATTED_STRING',
        }

    class ModelWithEnumsView(ModelView):
        datamodel = SQLAInterface(ModelWithEnums)

    # ---- registration (order kept as in the original) ---------------------
    self.appbuilder.add_view(Model1View, "Model1", category='Model1')
    self.appbuilder.add_view(Model1ViewWithRedirects, "Model1ViewWithRedirects", category='Model1')
    self.appbuilder.add_view(Model1CompactView, "Model1Compact", category='Model1')
    self.appbuilder.add_view(Model1MasterView, "Model1Master", category='Model1')
    self.appbuilder.add_view(Model1MasterChartView, "Model1MasterChart", category='Model1')
    self.appbuilder.add_view(Model1Filtered1View, "Model1Filtered1", category='Model1')
    self.appbuilder.add_view(Model1Filtered2View, "Model1Filtered2", category='Model1')
    self.appbuilder.add_view(Model1FormattedView, "Model1FormattedView", category='Model1FormattedView')

    self.appbuilder.add_view(Model2View, "Model2")
    self.appbuilder.add_view(Model22View, "Model22")
    # Model2View is passed a second time, with an explicit href to its add form.
    self.appbuilder.add_view(Model2View, "Model2 Add", href='/model2view/add')
    self.appbuilder.add_view(Model2ChartView, "Model2 Chart")
    self.appbuilder.add_view(Model2GroupByChartView, "Model2 Group By Chart")
    self.appbuilder.add_view(Model2DirectByChartView, "Model2 Direct By Chart")
    self.appbuilder.add_view(Model2TimeChartView, "Model2 Time Chart")
    self.appbuilder.add_view(Model2DirectChartView, "Model2 Direct Chart")
    self.appbuilder.add_view(ModelWithEnumsView, "ModelWithEnums")
    self.appbuilder.add_view(PSView, "Generic DS PS View", category='PSView')

    # Account the tests log in with.
    role_admin = self.appbuilder.sm.find_role('Admin')
    self.appbuilder.sm.add_user('admin', 'admin', 'user', '*****@*****.**', role_admin, 'general')
class FlaskTestCase(unittest.TestCase): def setUp(self): from flask import Flask from flask_appbuilder import AppBuilder from flask_appbuilder.models.sqla.interface import SQLAInterface from flask_appbuilder.views import ModelView self.app = Flask(__name__) self.basedir = os.path.abspath(os.path.dirname(__file__)) self.app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///' self.app.config['CSRF_ENABLED'] = False self.app.config['SECRET_KEY'] = 'thisismyscretkey' self.app.config['WTF_CSRF_ENABLED'] = False self.db = SQLA(self.app) self.appbuilder = AppBuilder(self.app, self.db.session) sess = PSSession() class PSView(ModelView): datamodel = GenericInterface(PSModel, sess) base_permissions = ['can_list', 'can_show'] list_columns = ['UID', 'C', 'CMD', 'TIME'] search_columns = ['UID', 'C', 'CMD'] class Model2View(ModelView): datamodel = SQLAInterface(Model2) list_columns = ['field_integer', 'field_float', 'field_string', 'field_method', 'group.field_string'] edit_form_query_rel_fields = {'group':[['field_string', FilterEqual, 'G2']]} add_form_query_rel_fields = {'group':[['field_string', FilterEqual, 'G1']]} class Model22View(ModelView): datamodel = SQLAInterface(Model2) list_columns = ['field_integer', 'field_float', 'field_string', 'field_method', 'group.field_string'] add_exclude_columns = ['excluded_string'] edit_exclude_columns = ['excluded_string'] show_exclude_columns = ['excluded_string'] class Model1View(ModelView): datamodel = SQLAInterface(Model1) related_views = [Model2View] list_columns = ['field_string', 'field_file'] class Model1CompactView(CompactCRUDMixin, ModelView): datamodel = SQLAInterface(Model1) class Model1ViewWithRedirects(ModelView): datamodel = SQLAInterface(Model1) obj_id = 1 def post_add_redirect(self): return redirect('model1viewwithredirects/show/{0}'.format(REDIRECT_OBJ_ID)) def post_edit_redirect(self): return redirect('model1viewwithredirects/show/{0}'.format(REDIRECT_OBJ_ID)) def post_delete_redirect(self): return 
redirect('model1viewwithredirects/show/{0}'.format(REDIRECT_OBJ_ID)) class Model1Filtered1View(ModelView): datamodel = SQLAInterface(Model1) base_filters = [['field_string', FilterStartsWith, 'a']] class Model1MasterView(MasterDetailView): datamodel = SQLAInterface(Model1) related_views = [Model2View] class Model1Filtered2View(ModelView): datamodel = SQLAInterface(Model1) base_filters = [['field_integer', FilterEqual, 0]] class Model2ChartView(ChartView): datamodel = SQLAInterface(Model2) chart_title = 'Test Model1 Chart' group_by_columns = ['field_string'] class Model2GroupByChartView(GroupByChartView): datamodel = SQLAInterface(Model2) chart_title = 'Test Model1 Chart' definitions = [ { 'group':'field_string', 'series':[(aggregate_sum,'field_integer', aggregate_avg, 'field_integer', aggregate_count,'field_integer') ] } ] class Model2DirectByChartView(DirectByChartView): datamodel = SQLAInterface(Model2) chart_title = 'Test Model1 Chart' definitions = [ { 'group':'field_string', 'series':['field_integer','field_float'] } ] class Model2TimeChartView(TimeChartView): datamodel = SQLAInterface(Model2) chart_title = 'Test Model1 Chart' group_by_columns = ['field_date'] class Model2DirectChartView(DirectChartView): datamodel = SQLAInterface(Model2) chart_title = 'Test Model1 Chart' direct_columns = {'stat1': ('group', 'field_integer')} class Model1MasterView(MasterDetailView): datamodel = SQLAInterface(Model1) related_views = [Model2View] class Model1MasterChartView(MasterDetailView): datamodel = SQLAInterface(Model1) related_views = [Model2DirectByChartView] class Model1FormattedView(ModelView): datamodel = SQLAInterface(Model1) list_columns = ['field_string'] show_columns = ['field_string'] formatters_columns = { 'field_string': lambda x: 'FORMATTED_STRING', } class ModelWithEnumsView(ModelView): datamodel = SQLAInterface(ModelWithEnums) self.appbuilder.add_view(Model1View, "Model1", category='Model1') self.appbuilder.add_view(Model1ViewWithRedirects, 
"Model1ViewWithRedirects", category='Model1') self.appbuilder.add_view(Model1CompactView, "Model1Compact", category='Model1') self.appbuilder.add_view(Model1MasterView, "Model1Master", category='Model1') self.appbuilder.add_view(Model1MasterChartView, "Model1MasterChart", category='Model1') self.appbuilder.add_view(Model1Filtered1View, "Model1Filtered1", category='Model1') self.appbuilder.add_view(Model1Filtered2View, "Model1Filtered2", category='Model1') self.appbuilder.add_view(Model1FormattedView, "Model1FormattedView", category='Model1FormattedView') self.appbuilder.add_view(Model2View, "Model2") self.appbuilder.add_view(Model22View, "Model22") self.appbuilder.add_view(Model2View, "Model2 Add", href='/model2view/add') self.appbuilder.add_view(Model2ChartView, "Model2 Chart") self.appbuilder.add_view(Model2GroupByChartView, "Model2 Group By Chart") self.appbuilder.add_view(Model2DirectByChartView, "Model2 Direct By Chart") self.appbuilder.add_view(Model2TimeChartView, "Model2 Time Chart") self.appbuilder.add_view(Model2DirectChartView, "Model2 Direct Chart") self.appbuilder.add_view(ModelWithEnumsView, "ModelWithEnums") self.appbuilder.add_view(PSView, "Generic DS PS View", category='PSView') role_admin = self.appbuilder.sm.find_role('Admin') self.appbuilder.sm.add_user('admin','admin','user','*****@*****.**',role_admin,'general') def tearDown(self): self.appbuilder = None self.app = None self.db = None log.debug("TEAR DOWN") """ --------------------------------- TEST HELPER FUNCTIONS --------------------------------- """ def login(self, client, username, password): # Login with default admin return client.post('/login/', data=dict( username=username, password=password ), follow_redirects=True) def logout(self, client): return client.get('/logout/') def insert_data(self): for x, i in zip(string.ascii_letters[:23], range(23)): model = Model1(field_string="%stest" % (x), field_integer=i) self.db.session.add(model) self.db.session.commit() def insert_data2(self): 
models1 = [Model1(field_string='G1'), Model1(field_string='G2'), Model1(field_string='G3')] for model1 in models1: try: self.db.session.add(model1) self.db.session.commit() for x, i in zip(string.ascii_letters[:10], range(10)): model = Model2(field_string="%stest" % (x), field_integer=random.randint(1, 10), field_float=random.uniform(0.0, 1.0), group=model1) year = random.choice(range(1900, 2012)) month = random.choice(range(1, 12)) day = random.choice(range(1, 28)) model.field_date = datetime.datetime(year, month, day) self.db.session.add(model) self.db.session.commit() except Exception as e: print("ERROR {0}".format(str(e))) self.db.session.rollback() def test_fab_views(self): """ Test views creation and registration """ eq_(len(self.appbuilder.baseviews), 30) # current minimal views are 12 def test_back(self): """ Test Back functionality """ with self.app.test_client() as c: self.login(c, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) rv = c.get('/model1view/list/?_flt_0_field_string=f') rv = c.get('/model2view/list/') rv = c.get('/back', follow_redirects=True) assert request.args['_flt_0_field_string'] == u'f' assert '/model1view/list/' == request.path def test_model_creation(self): """ Test Model creation """ from sqlalchemy.engine.reflection import Inspector engine = self.db.session.get_bind(mapper=None, clause=None) inspector = Inspector.from_engine(engine) # Check if tables exist ok_('model1' in inspector.get_table_names()) ok_('model2' in inspector.get_table_names()) ok_('model_with_enums' in inspector.get_table_names()) def test_index(self): """ Test initial access and index message """ client = self.app.test_client() # Check for Welcome Message rv = client.get('/') data = rv.data.decode('utf-8') ok_(DEFAULT_INDEX_STRING in data) def test_sec_login(self): """ Test Security Login, Logout, invalid login, invalid access """ client = self.app.test_client() # Try to List and Redirect to Login rv = client.get('/model1view/list/') eq_(rv.status_code, 302) rv = 
client.get('/model2view/list/') eq_(rv.status_code, 302) # Login and list with admin self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) rv = client.get('/model1view/list/') eq_(rv.status_code, 200) rv = client.get('/model2view/list/') eq_(rv.status_code, 200) # Logout and and try to list self.logout(client) rv = client.get('/model1view/list/') eq_(rv.status_code, 302) rv = client.get('/model2view/list/') eq_(rv.status_code, 302) # Invalid Login rv = self.login(client, DEFAULT_ADMIN_USER, 'password') data = rv.data.decode('utf-8') ok_(INVALID_LOGIN_STRING in data) def test_sec_reset_password(self): """ Test Security reset password """ client = self.app.test_client() # Try Reset My password rv = client.get('/users/action/resetmypassword/1', follow_redirects=True) data = rv.data.decode('utf-8') ok_(ACCESS_IS_DENIED in data) #Reset My password rv = self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) rv = client.get('/users/action/resetmypassword/1', follow_redirects=True) data = rv.data.decode('utf-8') ok_("Reset Password Form" in data) rv = client.post('/resetmypassword/form', data=dict(password='******', conf_password='******'), follow_redirects=True) eq_(rv.status_code, 200) self.logout(client) self.login(client, DEFAULT_ADMIN_USER, 'password') rv = client.post('/resetmypassword/form', data=dict(password=DEFAULT_ADMIN_PASSWORD, conf_password=DEFAULT_ADMIN_PASSWORD), follow_redirects=True) eq_(rv.status_code, 200) #Reset Password Admin rv = client.get('/users/action/resetpasswords/1', follow_redirects=True) data = rv.data.decode('utf-8') ok_("Reset Password Form" in data) rv = client.post('/resetmypassword/form', data=dict(password=DEFAULT_ADMIN_PASSWORD, conf_password=DEFAULT_ADMIN_PASSWORD), follow_redirects=True) eq_(rv.status_code, 200) def test_generic_interface(self): """ Test Generic Interface for generic-alter datasource """ client = self.app.test_client() rv = self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) rv = 
client.get('/psview/list') data = rv.data.decode('utf-8') def test_model_crud(self): """ Test Model add, delete, edit """ client = self.app.test_client() rv = self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) rv = client.post('/model1view/add', data=dict(field_string='test1', field_integer='1', field_float='0.12', field_date='2014-01-01'), follow_redirects=True) eq_(rv.status_code, 200) model = self.db.session.query(Model1).first() eq_(model.field_string, u'test1') eq_(model.field_integer, 1) rv = client.post('/model1view/edit/1', data=dict(field_string='test2', field_integer='2'), follow_redirects=True) eq_(rv.status_code, 200) model = self.db.session.query(Model1).first() eq_(model.field_string, u'test2') eq_(model.field_integer, 2) rv = client.get('/model1view/delete/1', follow_redirects=True) eq_(rv.status_code, 200) model = self.db.session.query(Model1).first() eq_(model, None) def test_model_crud_with_enum(self): """ Test Model add, delete, edit for Model with Enum Columns """ client = self.app.test_client() rv = self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) data = {'enum1': u'e1'} if _has_enum: data['enum2'] = 'e1' rv = client.post('/modelwithenumsview/add', data=data, follow_redirects=True) eq_(rv.status_code, 200) model = self.db.session.query(ModelWithEnums).first() eq_(model.enum1, u'e1') if _has_enum: eq_(model.enum2, TestEnum.e1) data = {'enum1': u'e2'} if _has_enum: data['enum2'] = 'e2' rv = client.post('/modelwithenumsview/edit/1', data=data, follow_redirects=True) eq_(rv.status_code, 200) model = self.db.session.query(ModelWithEnums).first() eq_(model.enum1, u'e2') if _has_enum: eq_(model.enum2, TestEnum.e2) rv = client.get('/modelwithenumsview/delete/1', follow_redirects=True) eq_(rv.status_code, 200) model = self.db.session.query(ModelWithEnums).first() eq_(model, None) def test_formatted_cols(self): """ Test ModelView's formatters_columns """ client = self.app.test_client() rv = self.login(client, DEFAULT_ADMIN_USER, 
DEFAULT_ADMIN_PASSWORD) self.insert_data() rv = client.get('/model1formattedview/list/') eq_(rv.status_code, 200) data = rv.data.decode('utf-8') ok_('FORMATTED_STRING' in data) rv = client.get('/model1formattedview/show/1') eq_(rv.status_code, 200) data = rv.data.decode('utf-8') ok_('FORMATTED_STRING' in data) def test_model_redirects(self): """ Test Model redirects after add, delete, edit """ client = self.app.test_client() rv = self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) model1 = Model1(field_string='Test Redirects') self.db.session.add(model1) model1.id = REDIRECT_OBJ_ID self.db.session.flush() rv = client.post('/model1viewwithredirects/add', data=dict(field_string='test_redirect', field_integer='1', field_float='0.12', field_date='2014-01-01'), follow_redirects=True) eq_(rv.status_code, 200) data = rv.data.decode('utf-8') ok_('Test Redirects' in data) model_id = self.db.session.query(Model1).filter_by(field_string='test_redirect').first().id rv = client.post('/model1viewwithredirects/edit/{0}'.format(model_id), data=dict(field_string='test_redirect_2', field_integer='2'), follow_redirects=True) eq_(rv.status_code, 200) ok_('Test Redirects' in data) rv = client.get('/model1viewwithredirects/delete/{0}'.format(model_id), follow_redirects=True) eq_(rv.status_code, 200) ok_('Test Redirects' in data) def test_excluded_cols(self): """ Test add_exclude_columns, edit_exclude_columns, show_exclude_columns """ client = self.app.test_client() rv = self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) rv = client.get('/model22view/add') eq_(rv.status_code, 200) data = rv.data.decode('utf-8') ok_('field_string' in data) ok_('field_integer' in data) ok_('field_float' in data) ok_('field_date' in data) ok_('excluded_string' not in data) self.insert_data2() rv = client.get('/model22view/edit/1') eq_(rv.status_code, 200) data = rv.data.decode('utf-8') ok_('field_string' in data) ok_('field_integer' in data) ok_('field_float' in data) ok_('field_date' 
in data) ok_('excluded_string' not in data) rv = client.get('/model22view/show/1') eq_(rv.status_code, 200) data = rv.data.decode('utf-8') ok_('Field String' in data) ok_('Field Integer' in data) ok_('Field Float' in data) ok_('Field Date' in data) ok_('Excluded String' not in data) def test_query_rel_fields(self): """ Test add and edit form related fields filter """ client = self.app.test_client() rv = self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) self.insert_data2() # Base filter string starts with rv = client.get('/model2view/add') data = rv.data.decode('utf-8') ok_('G1' in data) ok_('G2' not in data) # Base filter string starts with rv = client.get('/model2view/edit/1') data = rv.data.decode('utf-8') ok_('G2' in data) ok_('G1' not in data) def test_model_list_order(self): """ Test Model order on lists """ self.insert_data() client = self.app.test_client() self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) rv = client.post('/model1view/list?_oc_Model1View=field_string&_od_Model1View=asc', follow_redirects=True) # TODO: Fix this 405 error # eq_(rv.status_code, 200) data = rv.data.decode('utf-8') # TODO # VALIDATE LIST IS ORDERED rv = client.post('/model1view/list?_oc_Model1View=field_string&_od_Model1View=desc', follow_redirects=True) # TODO: Fix this 405 error # eq_(rv.status_code, 200) data = rv.data.decode('utf-8') # TODO # VALIDATE LIST IS ORDERED def test_model_add_validation(self): """ Test Model add validations """ client = self.app.test_client() self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) rv = client.post('/model1view/add', data=dict(field_string='test1', field_integer='1'), follow_redirects=True) eq_(rv.status_code, 200) rv = client.post('/model1view/add', data=dict(field_string='test1', field_integer='2'), follow_redirects=True) eq_(rv.status_code, 200) data = rv.data.decode('utf-8') ok_(UNIQUE_VALIDATION_STRING in data) model = self.db.session.query(Model1).all() eq_(len(model), 1) rv = 
client.post('/model1view/add', data=dict(field_string='', field_integer='1'), follow_redirects=True) eq_(rv.status_code, 200) data = rv.data.decode('utf-8') ok_(NOTNULL_VALIDATION_STRING in data) model = self.db.session.query(Model1).all() eq_(len(model), 1) def test_model_edit_validation(self): """ Test Model edit validations """ client = self.app.test_client() self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) client.post('/model1view/add', data=dict(field_string='test1', field_integer='1'), follow_redirects=True) client.post('/model1view/add', data=dict(field_string='test2', field_integer='1'), follow_redirects=True) rv = client.post('/model1view/edit/1', data=dict(field_string='test2', field_integer='2'), follow_redirects=True) eq_(rv.status_code, 200) data = rv.data.decode('utf-8') ok_(UNIQUE_VALIDATION_STRING in data) rv = client.post('/model1view/edit/1', data=dict(field_string='', field_integer='2'), follow_redirects=True) eq_(rv.status_code, 200) data = rv.data.decode('utf-8') ok_(NOTNULL_VALIDATION_STRING in data) def test_model_base_filter(self): """ Test Model base filtered views """ client = self.app.test_client() self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) self.insert_data() models = self.db.session.query(Model1).all() eq_(len(models), 23) # Base filter string starts with rv = client.get('/model1filtered1view/list/') data = rv.data.decode('utf-8') ok_('atest' in data) ok_('btest' not in data) # Base filter integer equals rv = client.get('/model1filtered2view/list/') data = rv.data.decode('utf-8') ok_('atest' in data) ok_('btest' not in data) def test_model_list_method_field(self): """ Tests a model's field has a method """ client = self.app.test_client() self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) self.insert_data2() rv = client.get('/model2view/list/') eq_(rv.status_code, 200) data = rv.data.decode('utf-8') ok_('field_method_value' in data) def test_compactCRUDMixin(self): """ Test CompactCRUD Mixin 
view """ client = self.app.test_client() self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) self.insert_data2() rv = client.get('/model1compactview/list/') eq_(rv.status_code, 200) def test_edit_add_form_action_prefix_for_compactCRUDMixin(self): """ Test form_action in add, form_action in edit (CompactCRUDMixin) """ client = self.app.test_client() self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) # Make sure we have something to edit. self.insert_data() prefix = '/some-prefix' base_url = 'http://localhost' + prefix session_form_action_key = 'Model1CompactView__session_form_action' with client as c: expected_form_action = prefix + '/model1compactview/add/?' c.get('/model1compactview/add/', base_url=base_url) ok_(session[session_form_action_key] == expected_form_action) expected_form_action = prefix + '/model1compactview/edit/1?' c.get('/model1compactview/edit/1', base_url=base_url) ok_(session[session_form_action_key] == expected_form_action) def test_charts_view(self): """ Test Various Chart views """ client = self.app.test_client() self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) self.insert_data2() log.info("CHART TEST") rv = client.get('/model2chartview/chart/') eq_(rv.status_code, 200) rv = client.get('/model2groupbychartview/chart/') eq_(rv.status_code, 200) rv = client.get('/model2directbychartview/chart/') eq_(rv.status_code, 200) rv = client.get('/model2timechartview/chart/') eq_(rv.status_code, 200) # TODO: fix this # rv = client.get('/model2directchartview/chart/') #eq_(rv.status_code, 200) def test_master_detail_view(self): """ Test Master detail view """ client = self.app.test_client() self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) self.insert_data2() rv = client.get('/model1masterview/list/') eq_(rv.status_code, 200) rv = client.get('/model1masterview/list/1') eq_(rv.status_code, 200) rv = client.get('/model1masterchartview/list/') eq_(rv.status_code, 200) rv = 
client.get('/model1masterchartview/list/1') eq_(rv.status_code, 200) def test_api_read(self): """ Testing the api/read endpoint """ client = self.app.test_client() self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) self.insert_data() rv = client.get('/model1formattedview/api/read') eq_(rv.status_code, 200) data = json.loads(rv.data.decode('utf-8')) assert 'result' in data assert 'pks' in data assert len(data.get('result')) > 10 def test_api_create(self): """ Testing the api/create endpoint """ client = self.app.test_client() self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) rv = client.post( '/model1view/api/create', data=dict(field_string='zzz'), follow_redirects=True) eq_(rv.status_code, 200) objs = self.db.session.query(Model1).all() eq_(len(objs), 1) def test_api_update(self): """ Validate that the api update endpoint updates [only] the fields in POST data """ client = self.app.test_client() self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) self.insert_data() item = self.db.session.query(Model1).filter_by(id=1).one() field_integer_before = item.field_integer rv = client.put( '/model1view/api/update/1', data=dict(field_string='zzz'), follow_redirects=True) eq_(rv.status_code, 200) item = self.db.session.query(Model1).filter_by(id=1).one() eq_(item.field_string, 'zzz') eq_(item.field_integer, field_integer_before)
class FlaskTestCase(FABTestCase):
    """Functional tests for Flask-AppBuilder views on top of MongoEngine.

    Every test gets a fresh Flask application and ``AppBuilder`` with a set
    of model, filtered, master/detail and chart views registered in
    ``setUp``. The ``Model1``/``Model2`` collections are dropped again in
    ``tearDown`` so that a failing test cannot leak documents into the next
    one (several tests assert exact document counts).
    """

    def setUp(self):
        """Create the app, register all views under test and the admin user."""
        from flask import Flask
        from flask_appbuilder import AppBuilder
        from flask_appbuilder.models.mongoengine.interface import MongoEngineInterface
        from flask_appbuilder import ModelView
        from flask_appbuilder.security.mongoengine.manager import SecurityManager

        self.app = Flask(__name__)
        # Fail loudly on undefined template variables instead of rendering "".
        self.app.jinja_env.undefined = jinja2.StrictUndefined
        self.basedir = os.path.abspath(os.path.dirname(__file__))
        self.app.config["MONGODB_SETTINGS"] = {"DB": "test"}
        self.app.config["CSRF_ENABLED"] = False
        self.app.config["SECRET_KEY"] = "thisismyscretkey"
        self.app.config["WTF_CSRF_ENABLED"] = False

        self.db = MongoEngine(self.app)
        self.appbuilder = AppBuilder(self.app, security_manager_class=SecurityManager)

        class Model2View(ModelView):
            datamodel = MongoEngineInterface(Model2)
            list_columns = [
                "field_integer",
                "field_float",
                "field_string",
                "field_method",
                "group.field_string",
            ]
            edit_form_query_rel_fields = {
                "group": [["field_string", FilterEqual, "G2"]]
            }
            add_form_query_rel_fields = {"group": [["field_string", FilterEqual, "G1"]]}
            add_exclude_columns = ["excluded_string"]

        class Model22View(ModelView):
            datamodel = MongoEngineInterface(Model2)
            list_columns = [
                "field_integer",
                "field_float",
                "field_string",
                "field_method",
                "group.field_string",
            ]
            add_exclude_columns = ["excluded_string"]
            edit_exclude_columns = ["excluded_string"]
            show_exclude_columns = ["excluded_string"]

        class Model1View(ModelView):
            datamodel = MongoEngineInterface(Model1)
            related_views = [Model2View]
            list_columns = ["field_string", "field_file"]

        class Model1CompactView(CompactCRUDMixin, ModelView):
            datamodel = MongoEngineInterface(Model1)

        class Model1Filtered1View(ModelView):
            datamodel = MongoEngineInterface(Model1)
            base_filters = [["field_string", FilterStartsWith, "a"]]

        class Model1MasterView(MasterDetailView):
            datamodel = MongoEngineInterface(Model1)
            related_views = [Model2View]

        class Model1Filtered2View(ModelView):
            datamodel = MongoEngineInterface(Model1)
            base_filters = [["field_integer", FilterEqual, 0]]

        class Model2GroupByChartView(GroupByChartView):
            datamodel = MongoEngineInterface(Model2)
            chart_title = "Test Model1 Chart"

            definitions = [
                {
                    "group": "field_string",
                    "series": [
                        (
                            aggregate_sum,
                            "field_integer",
                            aggregate_avg,
                            "field_integer",
                            aggregate_count,
                            "field_integer",
                        )
                    ],
                }
            ]

        class Model2DirectByChartView(DirectByChartView):
            datamodel = MongoEngineInterface(Model2)
            chart_title = "Test Model1 Chart"

            definitions = [
                {"group": "field_string", "series": ["field_integer", "field_float"]}
            ]

        class Model2DirectChartView(DirectChartView):
            datamodel = MongoEngineInterface(Model2)
            chart_title = "Test Model1 Chart"
            direct_columns = {"stat1": ("group", "field_integer")}

        class Model1MasterChartView(MasterDetailView):
            datamodel = MongoEngineInterface(Model1)
            related_views = [Model2DirectByChartView]

        self.appbuilder.add_view(Model1View, "Model1", category="Model1")
        self.appbuilder.add_view(Model1CompactView, "Model1Compact", category="Model1")
        self.appbuilder.add_view(Model1MasterView, "Model1Master", category="Model1")
        self.appbuilder.add_view(
            Model1MasterChartView, "Model1MasterChart", category="Model1"
        )
        self.appbuilder.add_view(
            Model1Filtered1View, "Model1Filtered1", category="Model1"
        )
        self.appbuilder.add_view(
            Model1Filtered2View, "Model1Filtered2", category="Model1"
        )

        self.appbuilder.add_view(Model2View, "Model2")
        self.appbuilder.add_view(Model22View, "Model22")
        self.appbuilder.add_view(Model2View, "Model2 Add", href="/model2view/add")
        self.appbuilder.add_view(Model2GroupByChartView, "Model2 Group By Chart")
        self.appbuilder.add_view(Model2DirectByChartView, "Model2 Direct By Chart")
        self.appbuilder.add_view(Model2DirectChartView, "Model2 Direct Chart")

        role_admin = self.appbuilder.sm.find_role("Admin")
        try:
            self.appbuilder.sm.add_user(
                "admin", "admin", "user", "*****@*****.**", role_admin, "general"
            )
        except Exception:
            # Best effort: user creation is allowed to fail, but leave a
            # trace instead of discarding the error silently.
            log.debug("Could not create the admin test user", exc_info=True)

    def tearDown(self):
        """Drop test collections and release the application objects."""
        try:
            # Tests only call clean_data() as their last statement, so a
            # failed assertion used to leave documents behind for later tests.
            self.clean_data()
        except Exception:
            # Cleanup must never mask the outcome of the test itself.
            log.exception("Could not drop the test collections")
        self.appbuilder = None
        self.app = None
        self.db = None
        log.debug("TEAR DOWN")

    # ---------------------------------
    #       TEST HELPER FUNCTIONS
    # ---------------------------------

    def insert_data(self):
        """Save 23 ``Model1`` documents: 'atest'/0, 'btest'/1, ..."""
        for i, x in enumerate(string.ascii_letters[:23]):
            model = Model1(field_string="%stest" % (x), field_integer=i)
            model.save()

    def insert_data2(self):
        """Save groups G1..G3 and ten randomised ``Model2`` documents each.

        Errors are logged and the remaining groups are still attempted, which
        keeps the fixture best-effort.
        """
        models1 = [
            Model1(field_string="G1"),
            Model1(field_string="G2"),
            Model1(field_string="G3"),
        ]
        for model1 in models1:
            try:
                model1.save()
                for x in string.ascii_letters[:10]:
                    model = Model2(
                        field_string="%stest" % (x),
                        field_integer=random.randint(1, 10),
                        field_float=random.uniform(0.0, 1.0),
                        group=model1,
                    )
                    year = random.randint(1900, 2011)
                    # randint is inclusive: December is now reachable, and
                    # day 28 is valid in every month.
                    month = random.randint(1, 12)
                    day = random.randint(1, 28)
                    model.field_date = datetime.datetime(year, month, day)
                    model.save()
            except Exception:
                log.exception("ERROR inserting test data for %s", model1.field_string)

    def clean_data(self):
        """Drop the ``Model1`` and ``Model2`` collections."""
        Model1.drop_collection()
        Model2.drop_collection()

    def test_fab_views(self):
        """
            Test views creation and registration
        """
        eq_(len(self.appbuilder.baseviews), 26)  # current minimal views are 26

    def test_index(self):
        """
            Test initial access and index message
        """
        client = self.app.test_client()

        # Check for Welcome Message
        rv = client.get("/")
        data = rv.data.decode("utf-8")
        ok_(DEFAULT_INDEX_STRING in data)

    def test_sec_login(self):
        """
            Test Security Login, Logout, invalid login, invalid access
        """
        client = self.app.test_client()

        # Try to List and Redirect to Login
        rv = client.get("/model1view/list/")
        eq_(rv.status_code, 302)
        rv = client.get("/model2view/list/")
        eq_(rv.status_code, 302)

        # Login and list with admin
        self.browser_login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD)
        rv = client.get("/model1view/list/")
        eq_(rv.status_code, 200)
        rv = client.get("/model2view/list/")
        eq_(rv.status_code, 200)

        # Logout and try to list
        self.browser_logout(client)
        rv = client.get("/model1view/list/")
        eq_(rv.status_code, 302)
        rv = client.get("/model2view/list/")
        eq_(rv.status_code, 302)

        # Invalid Login
        rv = self.browser_login(client, DEFAULT_ADMIN_USER, "password")
        data = rv.data.decode("utf-8")
        ok_(INVALID_LOGIN_STRING in data)

    def test_sec_reset_password(self):
        """
            Test Security reset password
        """
        from flask_appbuilder.security.mongoengine.models import User

        client = self.app.test_client()

        # Try Reset My password while anonymous: access must be denied
        user = User.objects.filter(**{"username": "******"})[0]
        rv = client.get(
            "/users/action/resetmypassword/{0}".format(user.id), follow_redirects=True
        )
        data = rv.data.decode("utf-8")
        ok_(ACCESS_IS_DENIED in data)

        # Reset My password
        self.browser_login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD)
        rv = client.get(
            "/users/action/resetmypassword/{0}".format(user.id), follow_redirects=True
        )
        data = rv.data.decode("utf-8")
        ok_("Reset Password Form" in data)
        rv = client.post(
            "/resetmypassword/form",
            data=dict(password="******", conf_password="******"),
            follow_redirects=True,
        )
        eq_(rv.status_code, 200)
        self.browser_logout(client)
        self.browser_login(client, DEFAULT_ADMIN_USER, "password")
        # Restore the default password so later tests can log in again.
        rv = client.post(
            "/resetmypassword/form",
            data=dict(
                password=DEFAULT_ADMIN_PASSWORD, conf_password=DEFAULT_ADMIN_PASSWORD
            ),
            follow_redirects=True,
        )
        eq_(rv.status_code, 200)

        # Reset Password Admin
        rv = client.get(
            "/users/action/resetpasswords/{0}".format(user.id), follow_redirects=True
        )
        data = rv.data.decode("utf-8")
        ok_("Reset Password Form" in data)
        rv = client.post(
            "/resetmypassword/form",
            data=dict(
                password=DEFAULT_ADMIN_PASSWORD, conf_password=DEFAULT_ADMIN_PASSWORD
            ),
            follow_redirects=True,
        )
        eq_(rv.status_code, 200)

    def test_generic_interface(self):
        """
            Test Generic Interface for generic-alter datasource
        """
        client = self.app.test_client()
        self.browser_login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD)
        rv = client.get("/psview/list")
        # Only checks that the response body is decodable.
        rv.data.decode("utf-8")

    def test_model_crud(self):
        """
            Test Model add, delete, edit
        """
        client = self.app.test_client()
        self.browser_login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD)

        rv = client.post(
            "/model1view/add",
            data=dict(
                field_string="test1",
                field_integer="1",
                field_float="0.12",
                field_date="2014-01-01 23:10:07",
            ),
            follow_redirects=True,
        )
        eq_(rv.status_code, 200)

        model = Model1.objects[0]
        eq_(model.field_string, u"test1")
        eq_(model.field_integer, 1)

        model1 = Model1.objects(field_string="test1")[0]
        rv = client.post(
            "/model1view/edit/{0}".format(model1.id),
            data=dict(field_string="test2", field_integer="2"),
            follow_redirects=True,
        )
        eq_(rv.status_code, 200)

        model = Model1.objects[0]
        eq_(model.field_string, u"test2")
        eq_(model.field_integer, 2)

        rv = client.get(
            "/model1view/delete/{0}".format(model.id), follow_redirects=True
        )
        eq_(rv.status_code, 200)
        model = Model1.objects
        eq_(len(model), 0)
        self.clean_data()

    def test_excluded_cols(self):
        """
            Test add_exclude_columns, edit_exclude_columns, show_exclude_columns
        """
        client = self.app.test_client()
        self.browser_login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD)
        rv = client.get("/model22view/add")
        eq_(rv.status_code, 200)
        data = rv.data.decode("utf-8")
        ok_("field_string" in data)
        ok_("field_integer" in data)
        ok_("field_float" in data)
        ok_("field_date" in data)
        ok_("excluded_string" not in data)
        self.insert_data2()
        model2 = Model2.objects[0]
        rv = client.get("/model22view/edit/{0}".format(model2.id))
        eq_(rv.status_code, 200)
        data = rv.data.decode("utf-8")
        ok_("field_string" in data)
        ok_("field_integer" in data)
        ok_("field_float" in data)
        ok_("field_date" in data)
        ok_("excluded_string" not in data)
        rv = client.get("/model22view/show/{0}".format(model2.id))
        eq_(rv.status_code, 200)
        data = rv.data.decode("utf-8")
        # The show view renders column labels rather than field names.
        ok_("Field String" in data)
        ok_("Field Integer" in data)
        ok_("Field Float" in data)
        ok_("Field Date" in data)
        ok_("Excluded String" not in data)
        self.clean_data()

    def test_query_rel_fields(self):
        """
            Test add and edit form related fields filter
        """
        client = self.app.test_client()
        self.browser_login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD)
        self.insert_data2()

        # Add form: add_form_query_rel_fields limits "group" to G1
        rv = client.get("/model2view/add")
        data = rv.data.decode("utf-8")
        ok_("G1" in data)
        ok_("G2" not in data)

        model2 = Model2.objects[0]
        # Edit form: edit_form_query_rel_fields limits "group" to G2
        rv = client.get("/model2view/edit/{0}".format(model2.id))
        data = rv.data.decode("utf-8")
        ok_("G2" in data)
        ok_("G1" not in data)
        self.clean_data()

    def test_model_list_order(self):
        """
            Test Model order on lists
        """
        self.insert_data()

        client = self.app.test_client()
        self.browser_login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD)

        rv = client.post(
            "/model1view/list?_oc_Model1View=field_string&_od_Model1View=asc",
            follow_redirects=True,
        )
        # TODO: fix this 405 Method not allowed error
        # eq_(rv.status_code, 200)
        rv.data.decode("utf-8")
        # TODO
        # VALIDATE LIST IS ORDERED
        rv = client.post(
            "/model1view/list?_oc_Model1View=field_string&_od_Model1View=desc",
            follow_redirects=True,
        )
        # TODO: fix this 405 Method not allowed error
        # eq_(rv.status_code, 200)
        rv.data.decode("utf-8")
        # TODO
        # VALIDATE LIST IS ORDERED
        self.clean_data()

    def test_model_add_validation(self):
        """
            Test Model add validations
        """
        client = self.app.test_client()
        self.browser_login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD)

        rv = client.post(
            "/model1view/add",
            data=dict(field_string="test1", field_integer="1"),
            follow_redirects=True,
        )
        eq_(rv.status_code, 200)

        # Same field_string again: must be rejected as a duplicate
        rv = client.post(
            "/model1view/add",
            data=dict(field_string="test1", field_integer="2"),
            follow_redirects=True,
        )
        eq_(rv.status_code, 200)
        data = rv.data.decode("utf-8")
        ok_(UNIQUE_VALIDATION_STRING in data)

        model = Model1.objects()
        eq_(len(model), 1)

        # Empty field_string: must be rejected as missing
        rv = client.post(
            "/model1view/add",
            data=dict(field_string="", field_integer="1"),
            follow_redirects=True,
        )
        eq_(rv.status_code, 200)
        data = rv.data.decode("utf-8")
        ok_(NOTNULL_VALIDATION_STRING in data)

        model = Model1.objects()
        eq_(len(model), 1)
        self.clean_data()

    def test_model_edit_validation(self):
        """
            Test Model edit validations
        """
        client = self.app.test_client()
        self.browser_login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD)

        client.post(
            "/model1view/add",
            data=dict(field_string="test1", field_integer="1"),
            follow_redirects=True,
        )
        model1 = Model1.objects(field_string="test1")[0]
        client.post(
            "/model1view/add",
            data=dict(field_string="test2", field_integer="1"),
            follow_redirects=True,
        )
        # Renaming test1 to the already existing test2 must be rejected
        rv = client.post(
            "/model1view/edit/{0}".format(model1.id),
            data=dict(field_string="test2", field_integer="2"),
            follow_redirects=True,
        )
        eq_(rv.status_code, 200)
        data = rv.data.decode("utf-8")
        ok_(UNIQUE_VALIDATION_STRING in data)

        rv = client.post(
            "/model1view/edit/{0}".format(model1.id),
            data=dict(field_string="", field_integer="2"),
            follow_redirects=True,
        )
        eq_(rv.status_code, 200)
        data = rv.data.decode("utf-8")
        ok_(NOTNULL_VALIDATION_STRING in data)
        self.clean_data()

    def test_model_base_filter(self):
        """
            Test Model base filtered views
        """
        client = self.app.test_client()
        self.browser_login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD)
        self.insert_data()
        models = Model1.objects()
        eq_(len(models), 23)

        # Base filter string starts with
        rv = client.get("/model1filtered1view/list/")
        data = rv.data.decode("utf-8")
        ok_("atest" in data)
        ok_("btest" not in data)

        # Base filter integer equals
        rv = client.get("/model1filtered2view/list/")
        data = rv.data.decode("utf-8")
        ok_("atest" in data)
        ok_("btest" not in data)
        self.clean_data()

    def test_model_list_method_field(self):
        """
            Tests a model's field has a method
        """
        client = self.app.test_client()
        self.browser_login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD)
        self.insert_data2()
        rv = client.get("/model2view/list/")
        eq_(rv.status_code, 200)
        data = rv.data.decode("utf-8")
        ok_("field_method_value" in data)
        self.clean_data()

    def test_compactCRUDMixin(self):
        """
            Test CompactCRUD Mixin view
        """
        client = self.app.test_client()
        self.browser_login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD)
        self.insert_data2()
        rv = client.get("/model1compactview/list/")
        eq_(rv.status_code, 200)
        self.clean_data()
    # Template and URL prefix used by this index view.
    index_template = "index.html"
    route_base = "/"

    @expose("/")
    def index(self):
        """Render the landing page template."""
        return self.render_template("index.html")

    @expose("/MP_verify_Wg4ba6rCnxiDaStD.txt")
    def verify(self):
        """Serve the static verification text file under its fixed URL."""
        # NOTE(review): relative path, unlike favicion() below which prefixes
        # app.root_path -- confirm both resolve against the same directory.
        return send_file('static/MP_verify_Wg4ba6rCnxiDaStD.txt')

    @expose("/favicon.ico")
    def favicion(self):
        """Serve the favicon from the application's static image folder."""
        # NOTE(review): the method name looks like a typo of "favicon"; it is
        # left unchanged because other code may refer to it by name.
        return send_file(app.root_path+"/static/img/favicon.ico")


# Application-wide AppBuilder using the custom index view.
appbuilder = AppBuilder(app, db.session, indexview=CustomeIndexView)


# The string below is disabled code kept for reference: an SQLite-only hook
# that turns on foreign-key enforcement for every new connection.
"""
from sqlalchemy.engine import Engine
from sqlalchemy import event

#Only include this for SQLLite constraints
@event.listens_for(Engine, "connect")
def set_sqlite_pragma(dbapi_connection, connection_record):
    # Will force sqllite contraint foreign keys
    cursor = dbapi_connection.cursor()
    cursor.execute("PRAGMA foreign_keys=ON")
    cursor.close()
"""
def create_app(config=None, session=None, testing=False, app_name="Airflow"):
    """Build the webserver Flask application and its AppBuilder.

    The created objects are also published through the module-level globals
    ``app`` and ``appbuilder``.

    :param config: optional mapping merged into ``app.config`` last
    :param session: optional session handed to AppBuilder instead of the
        one owned by the SQLA extension created here
    :param testing: value stored in ``app.config['TESTING']``; when true the
        experimental API module is reloaded before it is registered
    :param app_name: value stored in ``app.config['APP_NAME']``
    :return: tuple ``(app, appbuilder)``
    :raises Exception: if the configured ``SECURITY_MANAGER_CLASS`` is not a
        subclass of ``AirflowSecurityManager``
    """
    global app, appbuilder

    app = Flask(__name__)
    if conf.getboolean('webserver', 'ENABLE_PROXY_FIX'):
        app.wsgi_app = ProxyFix(app.wsgi_app)
    app.secret_key = conf.get('webserver', 'SECRET_KEY')

    # File based settings first, then the hard defaults, then caller overrides.
    app.config.from_pyfile(settings.WEBSERVER_CONFIG, silent=True)
    app.config.update({
        'APP_NAME': app_name,
        'TESTING': testing,
        'SQLALCHEMY_TRACK_MODIFICATIONS': False,
        'SESSION_COOKIE_HTTPONLY': True,
        'SESSION_COOKIE_SECURE': conf.getboolean('webserver', 'COOKIE_SECURE'),
        'SESSION_COOKIE_SAMESITE': conf.get('webserver', 'COOKIE_SAMESITE'),
    })
    if config:
        app.config.from_mapping(config)

    # Configure the JSON encoder used by `|tojson` filter from Flask
    app.json_encoder = AirflowJsonEncoder

    csrf.init_app(app)
    db = SQLA(app)

    from airflow import api
    api.load_auth()
    api.api_auth.init_app(app)

    # Constructed only for its registration with the app; nothing reads the
    # returned object afterwards.
    Cache(app=app, config={'CACHE_TYPE': 'filesystem', 'CACHE_DIR': '/tmp'})

    from airflow.www.blueprints import routes
    app.register_blueprint(routes)

    configure_logging()
    configure_manifest_files(app)

    with app.app_context():
        from airflow.www.security import AirflowSecurityManager

        security_manager_class = (
            app.config.get('SECURITY_MANAGER_CLASS') or AirflowSecurityManager
        )
        if not issubclass(security_manager_class, AirflowSecurityManager):
            raise Exception(
                """Your CUSTOM_SECURITY_MANAGER must now extend AirflowSecurityManager, not FAB's security manager.""")

        appbuilder = AppBuilder(
            app,
            session or db.session,
            security_manager_class=security_manager_class,
            base_template='appbuilder/baselayout.html')

        def init_views(appbuilder):
            """Register the built-in views and menu entries, in menu order."""
            from airflow.www import views

            for hidden_view in (views.Airflow(),
                                views.DagModelView(),
                                views.ConfigurationView(),
                                views.VersionView()):
                appbuilder.add_view_no_menu(hidden_view)

            appbuilder.add_view(views.DagRunModelView,
                                "DAG Runs",
                                category="Browse",
                                category_icon="fa-globe")
            for view_cls, label in ((views.JobModelView, "Jobs"),
                                    (views.LogModelView, "Logs"),
                                    (views.SlaMissModelView, "SLA Misses"),
                                    (views.TaskInstanceModelView, "Task Instances")):
                appbuilder.add_view(view_cls, label, category="Browse")

            appbuilder.add_link("Configurations",
                                href='/configuration',
                                category="Admin",
                                category_icon="fa-user")
            for view_cls, label in ((views.ConnectionModelView, "Connections"),
                                    (views.PoolModelView, "Pools"),
                                    (views.VariableModelView, "Variables"),
                                    (views.XComModelView, "XComs")):
                appbuilder.add_view(view_cls, label, category="Admin")

            appbuilder.add_link("Documentation",
                                href='https://airflow.apache.org/',
                                category="Docs",
                                category_icon="fa-cube")
            appbuilder.add_link("GitHub",
                                href='https://github.com/apache/airflow',
                                category="Docs")
            appbuilder.add_link('Version',
                                href='/version',
                                category='About',
                                category_icon='fa-th')

            def integrate_plugins():
                """Integrate plugins to the context"""
                from airflow.plugins_manager import (
                    flask_appbuilder_views, flask_appbuilder_menu_links)

                for plugin_view in flask_appbuilder_views:
                    log.debug("Adding view %s", plugin_view["name"])
                    appbuilder.add_view(plugin_view["view"],
                                        plugin_view["name"],
                                        category=plugin_view["category"])
                # Menu links are added in alphabetical order of their name.
                ordered_links = sorted(flask_appbuilder_menu_links,
                                       key=lambda link: link["name"])
                for menu_link in ordered_links:
                    log.debug("Adding menu link %s", menu_link["name"])
                    appbuilder.add_link(menu_link["name"],
                                        href=menu_link["href"],
                                        category=menu_link["category"],
                                        category_icon=menu_link["category_icon"])

            integrate_plugins()
            # Garbage collect old permissions/views after they have been modified.
            # Otherwise, when the name of a view or menu is changed, the framework
            # will add the new Views and Menus names to the backend, but will not
            # delete the old ones.

        def init_plugin_blueprints(app):
            """Register every blueprint contributed by plugins."""
            from airflow.plugins_manager import flask_blueprints

            for plugin_bp in flask_blueprints:
                log.debug("Adding blueprint %s:%s",
                          plugin_bp["name"], plugin_bp["blueprint"].import_name)
                app.register_blueprint(plugin_bp["blueprint"])

        init_views(appbuilder)
        init_plugin_blueprints(app)

        appbuilder.sm.sync_roles()

        from airflow.www.api.experimental import endpoints as experimental_endpoints
        # required for testing purposes otherwise the module retains
        # a link to the default_auth
        if app.config['TESTING']:
            import importlib
            importlib.reload(experimental_endpoints)

        app.register_blueprint(experimental_endpoints.api_experimental,
                               url_prefix='/api/experimental')

        @app.context_processor
        def jinja_globals():
            """Expose host name and navbar colour to every template."""
            return {
                'hostname': socket.getfqdn(),
                'navbar_color': conf.get('webserver', 'NAVBAR_COLOR'),
            }

        @app.teardown_appcontext
        def shutdown_session(exception=None):
            """Remove the scoped session when the app context is torn down."""
            settings.Session.remove()

    return app, appbuilder
from flask import Flask, redirect
from flask_appbuilder import SQLA, AppBuilder, IndexView
from flask_appbuilder.baseviews import expose
from flask_migrate import Migrate

# Absolute package directory. Taking abspath() first keeps the value usable
# even when __file__ carries no directory component, in which case dirname()
# alone returns '' and the old '' + "/migrations" pointed at the filesystem
# root.
APP_DIR = os.path.dirname(os.path.abspath(__file__))

# Logging configuration
logging.basicConfig(format='%(asctime)s:%(levelname)s:%(name)s:%(message)s')
logging.getLogger().setLevel(logging.DEBUG)

app = Flask(__name__)
app.config.from_pyfile('config.py')
db = SQLA(app)

# Migration scripts live in the "migrations" folder next to this module.
migrate = Migrate(app, db, directory=os.path.join(APP_DIR, "migrations"))


class MyIndexView(IndexView):
    """Index view that sends visitors of '/' to the binders welcome page."""

    @expose('/')
    def index(self):
        """Redirect the site root to '/binders/welcome'."""
        return redirect('/binders/welcome')


appbuilder = AppBuilder(app=app,
                        session=db.session,
                        base_template='binders/base.html',
                        indexview=MyIndexView)

# Imported last, after ``appbuilder`` exists.
# NOTE(review): presumably the views module imports names from this module --
# confirm before moving this import.
from binders import views  # noqa
class TestSecurity(unittest.TestCase):
    """Tests for ``AirflowSecurityManager`` on an in-memory SQLite app.

    ``setUp`` builds a Flask app with an AppBuilder that uses the Airflow
    security manager, registers two sample views and creates an admin user
    that the DAG-permission helpers below operate on.
    """

    def setUp(self):
        self.app = Flask(__name__)
        self.app.config.update({
            'SQLALCHEMY_DATABASE_URI': 'sqlite:///',
            'SECRET_KEY': 'secret_key',
            'CSRF_ENABLED': False,
            'WTF_CSRF_ENABLED': False,
        })
        self.db = SQLA(self.app)
        self.appbuilder = AppBuilder(
            self.app,
            self.db.session,
            security_manager_class=AirflowSecurityManager,
        )
        self.security_manager = self.appbuilder.sm

        sample_views = (
            (SomeBaseView, "SomeBaseView", "BaseViews"),
            (SomeModelView, "SomeModelView", "ModelViews"),
        )
        for view_cls, label, menu in sample_views:
            self.appbuilder.add_view(view_cls, label, category=menu)

        admin_role = self.security_manager.find_role('Admin')
        self.user = self.appbuilder.sm.add_user(
            'admin', 'admin', 'user', '*****@*****.**', admin_role, 'general')
        log.debug("Complete setup!")

    def expect_user_is_in_role(self, user, rolename):
        """Make ``rolename`` the only role of ``user``, creating it if needed."""
        manager = self.security_manager
        manager.init_role(rolename, [], [])
        found = manager.find_role(rolename)
        if not found:
            manager.add_role(rolename)
            found = manager.find_role(rolename)
        user.roles = [found]
        manager.update_user(user)

    def assert_user_has_dag_perms(self, perms, dag_id):
        """Assert that the test user holds every permission in ``perms``."""
        for permission in perms:
            granted = self._has_dag_perm(permission, dag_id)
            self.assertTrue(
                granted,
                "User should have '{}' on DAG '{}'".format(permission, dag_id))

    def assert_user_does_not_have_dag_perms(self, dag_id, perms):
        """Assert that the test user holds none of the permissions in ``perms``."""
        for permission in perms:
            granted = self._has_dag_perm(permission, dag_id)
            self.assertFalse(
                granted,
                "User should not have '{}' on DAG '{}'".format(permission, dag_id))

    def _has_dag_perm(self, perm, dag_id):
        """Return the ``has_access`` result for the test user."""
        return self.security_manager.has_access(perm, dag_id, self.user)

    def tearDown(self):
        self.appbuilder = None
        self.app = None
        self.db = None
        log.debug("Complete teardown!")

    def test_init_role_baseview(self):
        name = 'MyRole1'
        wanted_perms = ['can_some_action']
        view_menus = ['SomeBaseView']

        self.security_manager.init_role(name, view_menus, wanted_perms)

        created = self.appbuilder.sm.find_role(name)
        self.assertIsNotNone(created)
        self.assertEqual(len(wanted_perms), len(created.permissions))

    def test_init_role_modelview(self):
        name = 'MyRole2'
        wanted_perms = ['can_list', 'can_show', 'can_add', 'can_edit',
                        'can_delete']
        view_menus = ['SomeModelView']

        self.security_manager.init_role(name, view_menus, wanted_perms)

        created = self.appbuilder.sm.find_role(name)
        self.assertIsNotNone(created)
        self.assertEqual(len(wanted_perms), len(created.permissions))

    def test_update_and_verify_permission_role(self):
        manager = self.security_manager
        name = 'Test_Role'
        manager.init_role(name, [], [])
        target = manager.find_role(name)

        edit_roles = manager.find_permission_view_menu('can_edit',
                                                       'RoleModelView')
        manager.add_permission_role(target, edit_roles)
        count_before = len(target.permissions)

        # A second init_role with empty inputs must not change the count.
        manager.init_role(name, [], [])
        count_after = len(target.permissions)

        self.assertEqual(count_before, count_after)

    def test_get_user_roles(self):
        fake_user = mock.MagicMock()
        fake_user.is_anonymous = False
        admin = self.appbuilder.sm.find_role('Admin')
        fake_user.roles = admin
        self.assertEqual(self.security_manager.get_user_roles(fake_user), admin)

    @mock.patch('airflow.www.security.AirflowSecurityManager.get_user_roles')
    def test_get_all_permissions_views(self, mock_get_user_roles):
        manager = self.security_manager
        name = 'MyRole1'
        manager.init_role(name, ['SomeBaseView'], ['can_some_action'])
        created = manager.find_role(name)

        mock_get_user_roles.return_value = [created]
        self.assertEqual(manager.get_all_permissions_views(),
                         {('can_some_action', 'SomeBaseView')})

        mock_get_user_roles.return_value = []
        self.assertEqual(len(manager.get_all_permissions_views()), 0)

    @mock.patch('airflow.www.security.AirflowSecurityManager'
                '.get_all_permissions_views')
    @mock.patch('airflow.www.security.AirflowSecurityManager'
                '.get_user_roles')
    def test_get_accessible_dag_ids(self, mock_get_user_roles,
                                    mock_get_all_permissions_views):
        manager = self.security_manager
        fake_user = mock.MagicMock()
        name = 'MyRole1'
        manager.init_role(name, ['dag_id'], ['can_dag_read'])
        created = manager.find_role(name)
        fake_user.roles = [created]
        fake_user.is_anonymous = False
        mock_get_all_permissions_views.return_value = {('can_dag_read', 'dag_id')}

        mock_get_user_roles.return_value = [created]
        self.assertEqual(manager.get_accessible_dag_ids(fake_user), {'dag_id'})

    @mock.patch('airflow.www.security.AirflowSecurityManager._has_view_access')
    def test_has_access(self, mock_has_view_access):
        fake_user = mock.MagicMock()
        fake_user.is_anonymous = False
        mock_has_view_access.return_value = True
        allowed = self.security_manager.has_access('perm', 'view', fake_user)
        self.assertTrue(allowed)

    def test_sync_perm_for_dag_creates_permissions_on_view_menus(self):
        manager = self.security_manager
        dag_name = 'TEST_DAG'
        manager.sync_perm_for_dag(dag_name, access_control=None)
        for dag_perm in manager.DAG_PERMS:
            self.assertIsNotNone(
                manager.find_permission_view_menu(dag_perm, dag_name))

    @mock.patch('airflow.www.security.AirflowSecurityManager._has_perm')
    @mock.patch('airflow.www.security.AirflowSecurityManager._has_role')
    def test_has_all_dag_access(self, mock_has_role, mock_has_perm):
        manager = self.security_manager

        mock_has_role.return_value = True
        self.assertTrue(manager.has_all_dags_access())

        mock_has_role.return_value = False
        mock_has_perm.return_value = False
        self.assertFalse(manager.has_all_dags_access())

        mock_has_perm.return_value = True
        self.assertTrue(manager.has_all_dags_access())

    def test_access_control_with_non_existent_role(self):
        unknown_role_acl = {
            'this-role-does-not-exist': ['can_dag_edit', 'can_dag_read']
        }
        with self.assertRaises(AirflowException) as context:
            self.security_manager.sync_perm_for_dag(
                dag_id='access-control-test',
                access_control=unknown_role_acl)
        self.assertIn("role does not exist", str(context.exception))

    def test_access_control_with_invalid_permission(self):
        invalid_permissions = [
            'can_varimport',  # a real permission, but not a member of DAG_PERMS
            'can_eat_pudding',  # clearly not a real permission
        ]
        for bad_permission in invalid_permissions:
            self.expect_user_is_in_role(self.user, rolename='team-a')
            with self.assertRaises(AirflowException) as context:
                self.security_manager.sync_perm_for_dag(
                    'access_control_test',
                    access_control={'team-a': {bad_permission}})
            self.assertIn("invalid permissions", str(context.exception))

    def test_access_control_is_set_on_init(self):
        dag_name = 'access_control_test'
        granted = ['can_dag_edit', 'can_dag_read']

        self.expect_user_is_in_role(self.user, rolename='team-a')
        self.security_manager.sync_perm_for_dag(
            dag_name,
            access_control={'team-a': ['can_dag_edit', 'can_dag_read']})
        self.assert_user_has_dag_perms(perms=granted, dag_id=dag_name)

        # After moving the user to another role the permissions are gone.
        self.expect_user_is_in_role(self.user, rolename='NOT-team-a')
        self.assert_user_does_not_have_dag_perms(perms=granted, dag_id=dag_name)

    def test_access_control_stale_perms_are_revoked(self):
        READ_WRITE = {'can_dag_read', 'can_dag_edit'}
        READ_ONLY = {'can_dag_read'}
        dag_name = 'access_control_test'

        self.expect_user_is_in_role(self.user, rolename='team-a')
        self.security_manager.sync_perm_for_dag(
            dag_name, access_control={'team-a': READ_WRITE})
        self.assert_user_has_dag_perms(perms=READ_WRITE, dag_id=dag_name)

        # Syncing again with fewer permissions must revoke the missing one.
        self.security_manager.sync_perm_for_dag(
            dag_name, access_control={'team-a': READ_ONLY})
        self.assert_user_has_dag_perms(perms=['can_dag_read'], dag_id=dag_name)
        self.assert_user_does_not_have_dag_perms(perms=['can_dag_edit'],
                                                 dag_id=dag_name)

    def test_no_additional_dag_permission_views_created(self):
        assoc_table = sqla_models.assoc_permissionview_role

        self.security_manager.sync_roles()
        rows_before = self.db.session().query(assoc_table).count()
        self.security_manager.sync_roles()
        rows_after = self.db.session().query(assoc_table).count()

        self.assertEqual(rows_before, rows_after)

    def test_override_role_vm(self):
        custom_manager = TestSecurityManager(appbuilder=self.appbuilder)
        viewer_vms = custom_manager.VIEWER_VMS
        self.assertEqual(len(viewer_vms), 1)
        self.assertEqual(custom_manager.VIEWER_VMS, {'Airflow'})
import logging
from flask import Flask
from flask_appbuilder import AppBuilder
from flask_sqlalchemy import SQLAlchemy
from flask_appbuilder.models import SQLA
from sqlalchemy import event

# Root logger: timestamped records, DEBUG level and above.
logging.basicConfig(format='%(asctime)s:%(levelname)s:%(name)s:%(message)s')
logging.getLogger().setLevel(logging.DEBUG)

# Application objects shared by the rest of the package.
app = Flask(__name__)
app.config.from_object('config')
db = SQLA(app)
session = db.session
appbuilder = AppBuilder(app, session)

# The string below is disabled code kept for reference; it is never executed.
"""
Only include this for SQLLite constraints

@event.listens_for(Engine, "connect")
def set_sqlite_pragma(dbapi_connection, connection_record):
    cursor = dbapi_connection.cursor()
    cursor.execute("PRAGMA foreign_keys=ON")
    cursor.close()
"""

# Models are imported before create_all() is called.
# NOTE(review): presumably so their tables are known to ``db`` by then --
# confirm against app.models.
from app.models import *
db.create_all()

# Views are imported last, once ``appbuilder`` exists.
from app import views
def setUp(self):
    """Create the Flask app, register the views under test and the admin user.

    The view classes are declared locally so that each test run binds them
    to a fresh ``AppBuilder`` instance.
    """
    from flask import Flask
    from flask_appbuilder import AppBuilder
    from flask_appbuilder.models.mongoengine.interface import MongoEngineInterface
    from flask_appbuilder import ModelView
    from flask_appbuilder.security.mongoengine.manager import SecurityManager

    self.app = Flask(__name__)
    self.basedir = os.path.abspath(os.path.dirname(__file__))
    self.app.config['MONGODB_SETTINGS'] = {'DB': 'test'}
    self.app.config['CSRF_ENABLED'] = False
    self.app.config['SECRET_KEY'] = 'thisismyscretkey'
    self.app.config['WTF_CSRF_ENABLED'] = False

    self.db = MongoEngine(self.app)
    self.appbuilder = AppBuilder(self.app, security_manager_class=SecurityManager)

    class Model2View(ModelView):
        datamodel = MongoEngineInterface(Model2)
        list_columns = ['field_integer', 'field_float', 'field_string',
                        'field_method', 'group.field_string']
        edit_form_query_rel_fields = {'group': [['field_string', FilterEqual, 'G2']]}
        add_form_query_rel_fields = {'group': [['field_string', FilterEqual, 'G1']]}

    class Model1View(ModelView):
        datamodel = MongoEngineInterface(Model1)
        related_views = [Model2View]
        list_columns = ['field_string', 'field_file']

    class Model1CompactView(CompactCRUDMixin, ModelView):
        datamodel = MongoEngineInterface(Model1)

    class Model1Filtered1View(ModelView):
        datamodel = MongoEngineInterface(Model1)
        base_filters = [['field_string', FilterStartsWith, 'a']]

    # Declared once; the original repeated this identical class further down.
    class Model1MasterView(MasterDetailView):
        datamodel = MongoEngineInterface(Model1)
        related_views = [Model2View]

    class Model1Filtered2View(ModelView):
        datamodel = MongoEngineInterface(Model1)
        base_filters = [['field_integer', FilterEqual, 0]]

    class Model2GroupByChartView(GroupByChartView):
        datamodel = MongoEngineInterface(Model2)
        chart_title = 'Test Model1 Chart'

        definitions = [
            {
                'group': 'field_string',
                'series': [(aggregate_sum, 'field_integer',
                            aggregate_avg, 'field_integer',
                            aggregate_count, 'field_integer')]
            }
        ]

    class Model2DirectByChartView(DirectByChartView):
        datamodel = MongoEngineInterface(Model2)
        chart_title = 'Test Model1 Chart'

        definitions = [
            {
                'group': 'field_string',
                'series': ['field_integer', 'field_float']
            }
        ]

    class Model2DirectChartView(DirectChartView):
        datamodel = MongoEngineInterface(Model2)
        chart_title = 'Test Model1 Chart'
        direct_columns = {'stat1': ('group', 'field_integer')}

    class Model1MasterChartView(MasterDetailView):
        datamodel = MongoEngineInterface(Model1)
        related_views = [Model2DirectByChartView]

    self.appbuilder.add_view(Model1View, "Model1", category='Model1')
    self.appbuilder.add_view(Model1CompactView, "Model1Compact", category='Model1')
    self.appbuilder.add_view(Model1MasterView, "Model1Master", category='Model1')
    self.appbuilder.add_view(Model1MasterChartView, "Model1MasterChart", category='Model1')
    self.appbuilder.add_view(Model1Filtered1View, "Model1Filtered1", category='Model1')
    self.appbuilder.add_view(Model1Filtered2View, "Model1Filtered2", category='Model1')

    self.appbuilder.add_view(Model2View, "Model2")
    self.appbuilder.add_view(Model2View, "Model2 Add", href='/model2view/add')
    self.appbuilder.add_view(Model2GroupByChartView, "Model2 Group By Chart")
    self.appbuilder.add_view(Model2DirectByChartView, "Model2 Direct By Chart")
    self.appbuilder.add_view(Model2DirectChartView, "Model2 Direct Chart")

    role_admin = self.appbuilder.sm.find_role('Admin')
    try:
        self.appbuilder.sm.add_user('admin', 'admin', 'user', '*****@*****.**',
                                    role_admin, 'general')
    except Exception:
        # Best effort: user creation is allowed to fail. Catching Exception
        # instead of a bare ``except`` lets KeyboardInterrupt and SystemExit
        # propagate.
        pass
def create_app(config=None, session=None, testing=False, app_name="Airflow"):
    """Create and configure the Flask application and its Flask-AppBuilder instance.

    The created objects are also published through the module-level ``app``
    and ``appbuilder`` globals.

    :param config: optional mapping applied on top of the configuration loaded
        from ``settings.WEBSERVER_CONFIG``; ignored when falsy.
    :param session: optional database session handed to ``AppBuilder``; when
        falsy, the session of the ``SQLA`` instance created here is used.
    :param testing: stored in ``app.config['TESTING']``; when true the
        experimental API endpoints module is reloaded before registration.
    :param app_name: stored in ``app.config['APP_NAME']``.
    :return: the ``(app, appbuilder)`` tuple.
    :raises Exception: if the configured ``SECURITY_MANAGER_CLASS`` is not a
        subclass of ``AirflowSecurityManager``.
    """
    global app, appbuilder
    app = Flask(__name__)
    if conf.getboolean('webserver', 'ENABLE_PROXY_FIX'):
        app.wsgi_app = ProxyFix(
            app.wsgi_app,
            x_for=conf.getint("webserver", "PROXY_FIX_X_FOR", fallback=1),
            x_proto=conf.getint("webserver", "PROXY_FIX_X_PROTO", fallback=1),
            x_host=conf.getint("webserver", "PROXY_FIX_X_HOST", fallback=1),
            x_port=conf.getint("webserver", "PROXY_FIX_X_PORT", fallback=1),
            x_prefix=conf.getint("webserver", "PROXY_FIX_X_PREFIX", fallback=1))
    app.secret_key = conf.get('webserver', 'SECRET_KEY')

    session_lifetime_days = conf.getint('webserver', 'SESSION_LIFETIME_DAYS', fallback=30)
    app.config['PERMANENT_SESSION_LIFETIME'] = timedelta(
        days=session_lifetime_days)

    app.config.from_pyfile(settings.WEBSERVER_CONFIG, silent=True)
    app.config['APP_NAME'] = app_name
    app.config['TESTING'] = testing
    app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False

    app.config['SESSION_COOKIE_HTTPONLY'] = True
    app.config['SESSION_COOKIE_SECURE'] = conf.getboolean(
        'webserver', 'COOKIE_SECURE')
    app.config['SESSION_COOKIE_SAMESITE'] = conf.get('webserver', 'COOKIE_SAMESITE')

    # Caller-supplied settings are applied last so they win over everything above.
    if config:
        app.config.from_mapping(config)

    # Configure the JSON encoder used by `|tojson` filter from Flask
    app.json_encoder = AirflowJsonEncoder

    csrf.init_app(app)

    db = SQLA(app)

    from airflow import api
    api.load_auth()
    api.API_AUTH.api_auth.init_app(app)

    Cache(app=app, config={'CACHE_TYPE': 'filesystem', 'CACHE_DIR': '/tmp'})

    from airflow.www.blueprints import routes
    app.register_blueprint(routes)

    configure_logging()
    configure_manifest_files(app)

    with app.app_context():
        from airflow.www.security import AirflowSecurityManager
        security_manager_class = app.config.get('SECURITY_MANAGER_CLASS') or \
            AirflowSecurityManager

        if not issubclass(security_manager_class, AirflowSecurityManager):
            # NOTE(review): the message names CUSTOM_SECURITY_MANAGER while the
            # key read above is SECURITY_MANAGER_CLASS -- confirm which is intended.
            raise Exception(
                "Your CUSTOM_SECURITY_MANAGER must now extend "
                "AirflowSecurityManager, not FAB's security manager.")

        appbuilder = AppBuilder(
            app,
            db.session if not session else session,
            security_manager_class=security_manager_class,
            base_template='airflow/master.html',
            update_perms=conf.getboolean(
                'webserver', 'UPDATE_FAB_PERMS'))

        def init_views(appbuilder):
            """Register the built-in views and menu links, then the plugin ones."""
            from airflow.www import views
            # Remove the session from scoped_session registry to avoid
            # reusing a session with a disconnected connection
            appbuilder.session.remove()
            appbuilder.add_view_no_menu(views.Airflow())
            appbuilder.add_view_no_menu(views.DagModelView())
            appbuilder.add_view(views.DagRunModelView,
                                "DAG Runs",
                                category="Browse",
                                category_icon="fa-globe")
            appbuilder.add_view(views.JobModelView,
                                "Jobs",
                                category="Browse")
            appbuilder.add_view(views.LogModelView,
                                "Logs",
                                category="Browse")
            appbuilder.add_view(views.SlaMissModelView,
                                "SLA Misses",
                                category="Browse")
            appbuilder.add_view(views.TaskInstanceModelView,
                                "Task Instances",
                                category="Browse")
            appbuilder.add_view(views.ConfigurationView,
                                "Configurations",
                                category="Admin",
                                category_icon="fa-user")
            appbuilder.add_view(views.ConnectionModelView,
                                "Connections",
                                category="Admin")
            appbuilder.add_view(views.PoolModelView,
                                "Pools",
                                category="Admin")
            appbuilder.add_view(views.VariableModelView,
                                "Variables",
                                category="Admin")
            appbuilder.add_view(views.XComModelView,
                                "XComs",
                                category="Admin")

            # Development versions link to the "latest" docs, releases to
            # the docs of their exact version.
            if "dev" in version.version:
                airflow_doc_site = "https://airflow.readthedocs.io/en/latest"
            else:
                airflow_doc_site = 'https://airflow.apache.org/docs/{}'.format(
                    version.version)

            appbuilder.add_link("Website",
                                href='https://airflow.apache.org',
                                category="Docs",
                                category_icon="fa-globe")
            appbuilder.add_link("Documentation",
                                href=airflow_doc_site,
                                category="Docs",
                                category_icon="fa-cube")
            appbuilder.add_link("GitHub",
                                href='https://github.com/apache/airflow',
                                category="Docs")
            appbuilder.add_view(views.VersionView,
                                'Version',
                                category='About',
                                category_icon='fa-th')

            def integrate_plugins():
                """Integrate plugins to the context"""
                from airflow.plugins_manager import (
                    flask_appbuilder_views, flask_appbuilder_menu_links)

                for v in flask_appbuilder_views:
                    log.debug("Adding view %s", v["name"])
                    appbuilder.add_view(v["view"],
                                        v["name"],
                                        category=v["category"])
                # Menu links are added in name order.
                for ml in sorted(flask_appbuilder_menu_links, key=lambda x: x["name"]):
                    log.debug("Adding menu link %s", ml["name"])
                    appbuilder.add_link(ml["name"],
                                        href=ml["href"],
                                        category=ml["category"],
                                        category_icon=ml["category_icon"])

            integrate_plugins()
            # Garbage collect old permissions/views after they have been modified.
            # Otherwise, when the name of a view or menu is changed, the framework
            # will add the new Views and Menus names to the backend, but will not
            # delete the old ones.
            # NOTE(review): no statement follows this comment in the visible
            # code -- confirm whether a cleanup call is missing here.

        def init_plugin_blueprints(app):
            """Register every blueprint exposed by the plugins manager."""
            from airflow.plugins_manager import flask_blueprints

            for bp in flask_blueprints:
                log.debug("Adding blueprint %s:%s", bp["name"], bp["blueprint"].import_name)
                app.register_blueprint(bp["blueprint"])

        init_views(appbuilder)
        init_plugin_blueprints(app)

        if conf.getboolean('webserver', 'UPDATE_FAB_PERMS'):
            security_manager = appbuilder.sm
            security_manager.sync_roles()

        from airflow.www.api.experimental import endpoints as e
        # required for testing purposes otherwise the module retains
        # a link to the default_auth
        if app.config['TESTING']:
            import importlib
            importlib.reload(e)

        app.register_blueprint(e.api_experimental, url_prefix='/api/experimental')

        @app.context_processor
        def jinja_globals():  # pylint: disable=unused-variable
            """Return the extra variables made available to every template."""
            # Named ``extra_globals`` so the builtin ``globals`` is not shadowed.
            extra_globals = {
                'hostname': socket.getfqdn() if conf.getboolean(
                    'webserver', 'EXPOSE_HOSTNAME', fallback=True) else 'redact',
                'navbar_color': conf.get('webserver', 'NAVBAR_COLOR'),
                'log_fetch_delay_sec': conf.getint(
                    'webserver', 'log_fetch_delay_sec', fallback=2),
                'log_auto_tailing_offset': conf.getint(
                    'webserver', 'log_auto_tailing_offset', fallback=30),
                'log_animation_speed': conf.getint(
                    'webserver', 'log_animation_speed', fallback=1000)
            }

            if 'analytics_tool' in conf.getsection('webserver'):
                extra_globals.update({
                    'analytics_tool': conf.get('webserver', 'ANALYTICS_TOOL'),
                    'analytics_id': conf.get('webserver', 'ANALYTICS_ID')
                })

            return extra_globals

        @app.before_request
        def before_request():
            """Apply the forced log-out lifetime when FORCE_LOG_OUT_AFTER is positive."""
            _force_log_out_after = conf.getint('webserver', 'FORCE_LOG_OUT_AFTER', fallback=0)
            if _force_log_out_after > 0:
                flask.session.permanent = True
                # NOTE(review): this uses ``datetime.timedelta`` while the top of
                # the function uses a bare ``timedelta``; both names must be
                # imported at file level -- confirm.
                app.permanent_session_lifetime = datetime.timedelta(
                    minutes=_force_log_out_after)
                flask.session.modified = True
                flask.g.user = flask_login.current_user

        @app.after_request
        def apply_caching(response):
            """Send ``X-Frame-Options: DENY`` when X_FRAME_ENABLED is false."""
            _x_frame_enabled = conf.getboolean('webserver', 'X_FRAME_ENABLED', fallback=True)
            if not _x_frame_enabled:
                response.headers["X-Frame-Options"] = "DENY"
            return response

        @app.teardown_appcontext
        def shutdown_session(exception=None):  # pylint: disable=unused-variable
            """Remove the scoped ``settings.Session`` when the app context ends."""
            settings.Session.remove()

        @app.before_request
        def make_session_permanent():
            """Mark the Flask session as permanent on every request."""
            flask_session.permanent = True

    return app, appbuilder
# Settings for the throw-away test application: a local PostgreSQL database,
# CSRF protection switched off, and the GeoAlchemy add-on manager enabled.
cfg = dict(
    SQLALCHEMY_DATABASE_URI='postgresql:///test',
    CSRF_ENABLED=False,
    WTF_CSRF_ENABLED=False,
    SECRET_KEY='bla',
    ADDON_MANAGERS=['fab_addon_geoalchemy.manager.GeoAlchemyManager'],
)

app = Flask('testapp')
app.config.update(cfg)

# The metadata is bound to an engine built from the configured database URI.
engine = create_engine(app.config['SQLALCHEMY_DATABASE_URI'])
metadata = MetaData(bind=engine)
db = SQLAlchemy(app, metadata=metadata)
db.session.commit()

appbuilder = AppBuilder(app, db.session)


class Observation(db.Model):
    """Model with a name and two POINT geometry columns (SRID 4326 and 3857)."""

    id = Column(Integer, primary_key=True)
    name = Column(String)
    location = Column(Geometry(geometry_type='POINT', srid=4326))
    location2 = Column(Geometry(geometry_type='POINT', srid=3857))

    def __repr__(self):
        # Fall back to an id-based label when the name is empty or unset.
        return self.name or 'Person Type %s' % self.id
class TestSecurity(unittest.TestCase):
    """Tests for ``AirflowSecurityManager`` on an in-memory SQLite Flask app."""

    def setUp(self):
        """Create the app, the app builder, two sample views and an admin user."""
        self.app = Flask(__name__)
        self.app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///'
        self.app.config['SECRET_KEY'] = 'secret_key'
        self.app.config['CSRF_ENABLED'] = False
        self.app.config['WTF_CSRF_ENABLED'] = False
        self.db = SQLA(self.app)
        self.appbuilder = AppBuilder(self.app,
                                     self.db.session,
                                     security_manager_class=AirflowSecurityManager)
        self.security_manager = self.appbuilder.sm
        self.appbuilder.add_view(SomeBaseView, "SomeBaseView", category="BaseViews")
        self.appbuilder.add_view(SomeModelView, "SomeModelView", category="ModelViews")
        role_admin = self.security_manager.find_role('Admin')
        self.user = self.appbuilder.sm.add_user('admin', 'admin', 'user', '*****@*****.**',
                                                role_admin, 'general')
        log.debug("Complete setup!")

    def tearDown(self):
        """Drop the references created in ``setUp``."""
        self.appbuilder = None
        self.app = None
        self.db = None
        log.debug("Complete teardown!")

    def test_init_role_baseview(self):
        """``init_role`` creates a role holding one permission per requested perm."""
        role_name = 'MyRole1'
        role_perms = ['can_some_action']
        role_vms = ['SomeBaseView']
        self.security_manager.init_role(role_name, role_vms, role_perms)
        role = self.appbuilder.sm.find_role(role_name)
        self.assertIsNotNone(role)
        self.assertEqual(len(role_perms), len(role.permissions))

    def test_init_role_modelview(self):
        """``init_role`` works for the CRUD permissions of a model view."""
        role_name = 'MyRole2'
        role_perms = ['can_list', 'can_show', 'can_add', 'can_edit', 'can_delete']
        role_vms = ['SomeModelView']
        self.security_manager.init_role(role_name, role_vms, role_perms)
        role = self.appbuilder.sm.find_role(role_name)
        self.assertIsNotNone(role)
        self.assertEqual(len(role_perms), len(role.permissions))

    def test_get_user_roles(self):
        """``get_user_roles`` returns the ``roles`` attribute of a non-anonymous user."""
        user = mock.MagicMock()
        user.is_anonymous = False
        roles = self.appbuilder.sm.find_role('Admin')
        user.roles = roles
        self.assertEqual(self.security_manager.get_user_roles(user), roles)

    @mock.patch('airflow.www_rbac.security.AirflowSecurityManager.get_user_roles')
    def test_get_all_permissions_views(self, mock_get_user_roles):
        """Permission/view pairs follow the roles reported by ``get_user_roles``."""
        role_name = 'MyRole1'
        role_perms = ['can_some_action']
        role_vms = ['SomeBaseView']
        self.security_manager.init_role(role_name, role_vms, role_perms)
        role = self.security_manager.find_role(role_name)

        mock_get_user_roles.return_value = [role]
        self.assertEqual(self.security_manager
                         .get_all_permissions_views(),
                         {('can_some_action', 'SomeBaseView')})

        # No roles means no permission/view pairs at all.
        mock_get_user_roles.return_value = []
        self.assertEqual(len(self.security_manager
                             .get_all_permissions_views()), 0)

    @mock.patch('airflow.www_rbac.security.AirflowSecurityManager'
                '.get_all_permissions_views')
    @mock.patch('airflow.www_rbac.security.AirflowSecurityManager'
                '.get_user_roles')
    def test_get_accessible_dag_ids(self, mock_get_user_roles,
                                    mock_get_all_permissions_views):
        """A ``can_dag_read`` permission on a DAG makes that DAG id accessible."""
        user = mock.MagicMock()
        role_name = 'MyRole1'
        role_perms = ['can_dag_read']
        role_vms = ['dag_id']
        self.security_manager.init_role(role_name, role_vms, role_perms)
        role = self.security_manager.find_role(role_name)
        user.roles = [role]
        user.is_anonymous = False
        mock_get_all_permissions_views.return_value = {('can_dag_read', 'dag_id')}

        mock_get_user_roles.return_value = [role]
        self.assertEqual(self.security_manager
                         .get_accessible_dag_ids(user), {'dag_id'})

    @mock.patch('airflow.www_rbac.security.AirflowSecurityManager._has_view_access')
    def test_has_access(self, mock_has_view_access):
        """``has_access`` is true when ``_has_view_access`` reports access."""
        user = mock.MagicMock()
        user.is_anonymous = False
        mock_has_view_access.return_value = True
        self.assertTrue(self.security_manager.has_access('perm', 'view', user))

    def test_sync_perm_for_dag(self):
        """``sync_perm_for_dag`` creates every permission in ``dag_perms`` for the DAG."""
        test_dag_id = 'TEST_DAG'
        self.security_manager.sync_perm_for_dag(test_dag_id)
        for dag_perm in dag_perms:
            self.assertIsNotNone(self.security_manager.
                                 find_permission_view_menu(dag_perm, test_dag_id))

    @mock.patch('airflow.www_rbac.security.AirflowSecurityManager._has_perm')
    @mock.patch('airflow.www_rbac.security.AirflowSecurityManager._has_role')
    def test_has_all_dag_access(self, mock_has_role, mock_has_perm):
        """``has_all_dags_access`` is true via ``_has_role`` or, failing that, ``_has_perm``."""
        mock_has_role.return_value = True
        self.assertTrue(self.security_manager.has_all_dags_access())

        mock_has_role.return_value = False
        mock_has_perm.return_value = False
        self.assertFalse(self.security_manager.has_all_dags_access())

        mock_has_perm.return_value = True
        self.assertTrue(self.security_manager.has_all_dags_access())
class FlaskTestCase(FABTestCase):
    """End-to-end tests of Flask-AppBuilder views backed by MongoEngine models.

    Every test builds a fresh app in ``setUp``, drives it through the Flask
    test client and, where it inserts documents, drops the ``Model1`` and
    ``Model2`` collections again via ``clean_data``.
    """

    def setUp(self):
        """Create the app, declare the view classes and register them."""
        from flask import Flask
        from flask_appbuilder import AppBuilder
        from flask_appbuilder.models.mongoengine.interface import MongoEngineInterface
        from flask_appbuilder import ModelView
        from flask_appbuilder.security.mongoengine.manager import SecurityManager

        self.app = Flask(__name__)
        # Fail loudly on any template variable that is not defined.
        self.app.jinja_env.undefined = jinja2.StrictUndefined
        self.basedir = os.path.abspath(os.path.dirname(__file__))
        self.app.config["MONGODB_SETTINGS"] = {"DB": "test"}
        self.app.config["CSRF_ENABLED"] = False
        self.app.config["SECRET_KEY"] = "thisismyscretkey"
        self.app.config["WTF_CSRF_ENABLED"] = False
        self.db = MongoEngine(self.app)
        self.appbuilder = AppBuilder(self.app, security_manager_class=SecurityManager)

        class Model2View(ModelView):
            datamodel = MongoEngineInterface(Model2)
            list_columns = [
                "field_integer",
                "field_float",
                "field_string",
                "field_method",
                "group.field_string",
            ]
            # The edit form only offers group "G2", the add form only "G1"
            # (checked by test_query_rel_fields).
            edit_form_query_rel_fields = {
                "group": [["field_string", FilterEqual, "G2"]]
            }
            add_form_query_rel_fields = {"group": [["field_string", FilterEqual, "G1"]]}
            add_exclude_columns = ["excluded_string"]

        class Model22View(ModelView):
            # Same model as Model2View, but "excluded_string" is hidden from
            # the add, edit and show views (checked by test_excluded_cols).
            datamodel = MongoEngineInterface(Model2)
            list_columns = [
                "field_integer",
                "field_float",
                "field_string",
                "field_method",
                "group.field_string",
            ]
            add_exclude_columns = ["excluded_string"]
            edit_exclude_columns = ["excluded_string"]
            show_exclude_columns = ["excluded_string"]

        class Model1View(ModelView):
            datamodel = MongoEngineInterface(Model1)
            related_views = [Model2View]
            list_columns = ["field_string", "field_file"]

        class Model1CompactView(CompactCRUDMixin, ModelView):
            datamodel = MongoEngineInterface(Model1)

        class Model1Filtered1View(ModelView):
            datamodel = MongoEngineInterface(Model1)
            base_filters = [["field_string", FilterStartsWith, "a"]]

        class Model1MasterView(MasterDetailView):
            datamodel = MongoEngineInterface(Model1)
            related_views = [Model2View]

        class Model1Filtered2View(ModelView):
            datamodel = MongoEngineInterface(Model1)
            base_filters = [["field_integer", FilterEqual, 0]]

        class Model2GroupByChartView(GroupByChartView):
            datamodel = MongoEngineInterface(Model2)
            chart_title = "Test Model1 Chart"
            definitions = [
                {
                    "group": "field_string",
                    "series": [
                        (
                            aggregate_sum,
                            "field_integer",
                            aggregate_avg,
                            "field_integer",
                            aggregate_count,
                            "field_integer",
                        )
                    ],
                }
            ]

        class Model2DirectByChartView(DirectByChartView):
            datamodel = MongoEngineInterface(Model2)
            chart_title = "Test Model1 Chart"
            definitions = [
                {"group": "field_string", "series": ["field_integer", "field_float"]}
            ]

        class Model2DirectChartView(DirectChartView):
            datamodel = MongoEngineInterface(Model2)
            chart_title = "Test Model1 Chart"
            direct_columns = {"stat1": ("group", "field_integer")}

        class Model1MasterChartView(MasterDetailView):
            datamodel = MongoEngineInterface(Model1)
            related_views = [Model2DirectByChartView]

        self.appbuilder.add_view(Model1View, "Model1", category="Model1")
        self.appbuilder.add_view(Model1CompactView, "Model1Compact", category="Model1")
        self.appbuilder.add_view(Model1MasterView, "Model1Master", category="Model1")
        self.appbuilder.add_view(
            Model1MasterChartView, "Model1MasterChart", category="Model1"
        )
        self.appbuilder.add_view(
            Model1Filtered1View, "Model1Filtered1", category="Model1"
        )
        self.appbuilder.add_view(
            Model1Filtered2View, "Model1Filtered2", category="Model1"
        )

        self.appbuilder.add_view(Model2View, "Model2")
        self.appbuilder.add_view(Model22View, "Model22")
        self.appbuilder.add_view(Model2View, "Model2 Add", href="/model2view/add")
        self.appbuilder.add_view(Model2GroupByChartView, "Model2 Group By Chart")
        self.appbuilder.add_view(Model2DirectByChartView, "Model2 Direct By Chart")
        self.appbuilder.add_view(Model2DirectChartView, "Model2 Direct Chart")

        role_admin = self.appbuilder.sm.find_role("Admin")
        try:
            self.appbuilder.sm.add_user(
                "admin", "admin", "user", "*****@*****.**", role_admin, "general"
            )
        except Exception:
            # Any failure to create the admin user is ignored here
            # (presumably because the user may already exist -- TODO confirm).
            pass

    def tearDown(self):
        """Drop the references created in ``setUp``."""
        self.appbuilder = None
        self.app = None
        self.db = None
        log.debug("TEAR DOWN")

    """ --------------------------------- TEST HELPER FUNCTIONS --------------------------------- """

    def insert_data(self):
        """Save 23 ``Model1`` documents: "atest".."wtest" with integers 0..22."""
        for x, i in zip(string.ascii_letters[:23], range(23)):
            model = Model1(field_string="%stest" % (x), field_integer=i)
            model.save()

    def insert_data2(self):
        """Save groups G1..G3 and ten randomised ``Model2`` documents per group.

        Errors raised while saving are printed and otherwise ignored.
        """
        models1 = [
            Model1(field_string="G1"),
            Model1(field_string="G2"),
            Model1(field_string="G3"),
        ]
        for model1 in models1:
            try:
                model1.save()
                for x, i in zip(string.ascii_letters[:10], range(10)):
                    model = Model2(
                        field_string="%stest" % (x),
                        field_integer=random.randint(1, 10),
                        field_float=random.uniform(0.0, 1.0),
                        group=model1,
                    )
                    year = random.choice(range(1900, 2012))
                    # range() excludes its end: months are 1..11, days 1..27.
                    month = random.choice(range(1, 12))
                    day = random.choice(range(1, 28))
                    model.field_date = datetime.datetime(year, month, day)
                    model.save()
            except Exception as e:
                print("ERROR {0}".format(str(e)))

    def clean_data(self):
        """Drop the ``Model1`` and ``Model2`` collections."""
        Model1.drop_collection()
        Model2.drop_collection()

    def test_fab_views(self):
        """Test views creation and registration."""
        # NOTE(review): the original comment below says 26 while the assertion
        # expects 27 -- confirm which number is current.
        eq_(len(self.appbuilder.baseviews), 27)  # current minimal views are 26

    def test_index(self):
        """Test initial access and index message."""
        client = self.app.test_client()

        # Check for Welcome Message
        rv = client.get("/")
        data = rv.data.decode("utf-8")
        ok_(DEFAULT_INDEX_STRING in data)

    def test_sec_login(self):
        """Test Security Login, Logout, invalid login, invalid access."""
        client = self.app.test_client()

        # Try to List and Redirect to Login
        rv = client.get("/model1view/list/")
        eq_(rv.status_code, 302)
        rv = client.get("/model2view/list/")
        eq_(rv.status_code, 302)

        # Login and list with admin
        self.browser_login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD)
        rv = client.get("/model1view/list/")
        eq_(rv.status_code, 200)
        rv = client.get("/model2view/list/")
        eq_(rv.status_code, 200)

        # Logout and try to list
        self.browser_logout(client)
        rv = client.get("/model1view/list/")
        eq_(rv.status_code, 302)
        rv = client.get("/model2view/list/")
        eq_(rv.status_code, 302)

        # Invalid Login
        rv = self.browser_login(client, DEFAULT_ADMIN_USER, "password")
        data = rv.data.decode("utf-8")
        ok_(INVALID_LOGIN_STRING in data)

    def test_sec_reset_password(self):
        """Test Security reset password."""
        from flask_appbuilder.security.mongoengine.models import User

        client = self.app.test_client()

        # Try Reset My password
        user = User.objects.filter(**{"username": "******"})[0]
        rv = client.get(
            "/users/action/resetmypassword/{0}".format(user.id), follow_redirects=True
        )
        # Werkzeug update to 0.15.X sends this action to wrong redirect
        # Old test was:
        # data = rv.data.decode("utf-8")
        # ok_(ACCESS_IS_DENIED in data)
        self.assertEqual(rv.status_code, 404)

        # Reset My password
        rv = self.browser_login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD)
        rv = client.get(
            "/users/action/resetmypassword/{0}".format(user.id), follow_redirects=True
        )
        data = rv.data.decode("utf-8")
        self.assertIn("Reset Password Form", data)
        rv = client.post(
            "/resetmypassword/form",
            data=dict(password="******", conf_password="******"),
            follow_redirects=True,
        )
        self.assertEqual(rv.status_code, 200)
        self.browser_logout(client)
        # Log in again with the literal "password", then post the default
        # admin password back through the same form.
        self.browser_login(client, DEFAULT_ADMIN_USER, "password")
        rv = client.post(
            "/resetmypassword/form",
            data=dict(
                password=DEFAULT_ADMIN_PASSWORD, conf_password=DEFAULT_ADMIN_PASSWORD
            ),
            follow_redirects=True,
        )
        self.assertEqual(rv.status_code, 200)

        # Reset Password Admin
        rv = client.get(
            "/users/action/resetpasswords/{0}".format(user.id), follow_redirects=True
        )
        data = rv.data.decode("utf-8")
        self.assertIn("Reset Password Form", data)
        # NOTE(review): this posts to the "resetmypassword" form although the
        # section exercises the admin "resetpasswords" action -- confirm.
        rv = client.post(
            "/resetmypassword/form",
            data=dict(
                password=DEFAULT_ADMIN_PASSWORD, conf_password=DEFAULT_ADMIN_PASSWORD
            ),
            follow_redirects=True,
        )
        self.assertEqual(rv.status_code, 200)

    def test_model_crud(self):
        """Test Model add, delete, edit."""
        client = self.app.test_client()
        rv = self.browser_login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD)

        # Add
        rv = client.post(
            "/model1view/add",
            data=dict(
                field_string="test1",
                field_integer="1",
                field_float="0.12",
                field_date="2014-01-01 23:10:07",
            ),
            follow_redirects=True,
        )
        eq_(rv.status_code, 200)

        model = Model1.objects[0]
        eq_(model.field_string, u"test1")
        eq_(model.field_integer, 1)

        # Edit
        model1 = Model1.objects(field_string="test1")[0]
        rv = client.post(
            "/model1view/edit/{0}".format(model1.id),
            data=dict(field_string="test2", field_integer="2"),
            follow_redirects=True,
        )
        eq_(rv.status_code, 200)

        model = Model1.objects[0]
        eq_(model.field_string, u"test2")
        eq_(model.field_integer, 2)

        # Delete
        rv = client.get(
            "/model1view/delete/{0}".format(model.id), follow_redirects=True
        )
        eq_(rv.status_code, 200)
        model = Model1.objects
        eq_(len(model), 0)
        self.clean_data()

    def test_excluded_cols(self):
        """Test add_exclude_columns, edit_exclude_columns, show_exclude_columns."""
        client = self.app.test_client()
        rv = self.browser_login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD)

        # Add form
        rv = client.get("/model22view/add")
        eq_(rv.status_code, 200)
        data = rv.data.decode("utf-8")
        ok_("field_string" in data)
        ok_("field_integer" in data)
        ok_("field_float" in data)
        ok_("field_date" in data)
        ok_("excluded_string" not in data)

        self.insert_data2()
        model2 = Model2.objects[0]

        # Edit form
        rv = client.get("/model22view/edit/{0}".format(model2.id))
        eq_(rv.status_code, 200)
        data = rv.data.decode("utf-8")
        ok_("field_string" in data)
        ok_("field_integer" in data)
        ok_("field_float" in data)
        ok_("field_date" in data)
        ok_("excluded_string" not in data)

        # Show view (checked by label text rather than field name)
        rv = client.get("/model22view/show/{0}".format(model2.id))
        eq_(rv.status_code, 200)
        data = rv.data.decode("utf-8")
        ok_("Field String" in data)
        ok_("Field Integer" in data)
        ok_("Field Float" in data)
        ok_("Field Date" in data)
        ok_("Excluded String" not in data)
        self.clean_data()

    def test_query_rel_fields(self):
        """Test add and edit form related fields filter."""
        client = self.app.test_client()
        rv = self.browser_login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD)
        self.insert_data2()

        # Add form: the related-field filter only offers group G1
        rv = client.get("/model2view/add")
        data = rv.data.decode("utf-8")
        ok_("G1" in data)
        ok_("G2" not in data)

        model2 = Model2.objects[0]
        # Edit form: the related-field filter only offers group G2
        rv = client.get("/model2view/edit/{0}".format(model2.id))
        data = rv.data.decode("utf-8")
        ok_("G2" in data)
        ok_("G1" not in data)
        self.clean_data()

    def test_model_list_order(self):
        """Test Model order on lists."""
        self.insert_data()

        client = self.app.test_client()
        self.browser_login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD)

        rv = client.post(
            "/model1view/list?_oc_Model1View=field_string&_od_Model1View=asc",
            follow_redirects=True,
        )
        # TODO: fix this 405 Method not allowed error
        # eq_(rv.status_code, 200)
        rv.data.decode("utf-8")
        # TODO
        # VALIDATE LIST IS ORDERED
        rv = client.post(
            "/model1view/list?_oc_Model1View=field_string&_od_Model1View=desc",
            follow_redirects=True,
        )
        # TODO: fix this 405 Method not allowed error
        # eq_(rv.status_code, 200)
        rv.data.decode("utf-8")
        # TODO
        # VALIDATE LIST IS ORDERED
        self.clean_data()

    def test_model_add_validation(self):
        """Test Model add validations."""
        client = self.app.test_client()
        self.browser_login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD)

        rv = client.post(
            "/model1view/add",
            data=dict(field_string="test1", field_integer="1"),
            follow_redirects=True,
        )
        eq_(rv.status_code, 200)

        # Adding the same field_string again must be rejected as a duplicate.
        rv = client.post(
            "/model1view/add",
            data=dict(field_string="test1", field_integer="2"),
            follow_redirects=True,
        )
        eq_(rv.status_code, 200)
        data = rv.data.decode("utf-8")
        ok_(UNIQUE_VALIDATION_STRING in data)

        model = Model1.objects()
        eq_(len(model), 1)

        # An empty field_string must be rejected as well.
        rv = client.post(
            "/model1view/add",
            data=dict(field_string="", field_integer="1"),
            follow_redirects=True,
        )
        eq_(rv.status_code, 200)
        data = rv.data.decode("utf-8")
        ok_(NOTNULL_VALIDATION_STRING in data)

        model = Model1.objects()
        eq_(len(model), 1)
        self.clean_data()

    def test_model_edit_validation(self):
        """Test Model edit validations."""
        client = self.app.test_client()
        self.browser_login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD)

        client.post(
            "/model1view/add",
            data=dict(field_string="test1", field_integer="1"),
            follow_redirects=True,
        )
        model1 = Model1.objects(field_string="test1")[0]
        client.post(
            "/model1view/add",
            data=dict(field_string="test2", field_integer="1"),
            follow_redirects=True,
        )
        # Renaming "test1" to the already existing "test2" must be rejected.
        rv = client.post(
            "/model1view/edit/{0}".format(model1.id),
            data=dict(field_string="test2", field_integer="2"),
            follow_redirects=True,
        )
        eq_(rv.status_code, 200)
        data = rv.data.decode("utf-8")
        ok_(UNIQUE_VALIDATION_STRING in data)

        # Clearing field_string must be rejected as well.
        rv = client.post(
            "/model1view/edit/{0}".format(model1.id),
            data=dict(field_string="", field_integer="2"),
            follow_redirects=True,
        )
        eq_(rv.status_code, 200)
        data = rv.data.decode("utf-8")
        ok_(NOTNULL_VALIDATION_STRING in data)
        self.clean_data()

    def test_model_base_filter(self):
        """Test Model base filtered views."""
        client = self.app.test_client()
        self.browser_login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD)
        self.insert_data()
        models = Model1.objects()
        eq_(len(models), 23)

        # Base filter string starts with
        rv = client.get("/model1filtered1view/list/")
        data = rv.data.decode("utf-8")
        ok_("atest" in data)
        ok_("btest" not in data)

        # Base filter integer equals
        rv = client.get("/model1filtered2view/list/")
        data = rv.data.decode("utf-8")
        ok_("atest" in data)
        ok_("btest" not in data)
        self.clean_data()

    def test_model_list_method_field(self):
        """Tests a model's field has a method."""
        client = self.app.test_client()
        self.browser_login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD)
        self.insert_data2()
        rv = client.get("/model2view/list/")
        eq_(rv.status_code, 200)
        data = rv.data.decode("utf-8")
        ok_("field_method_value" in data)
        self.clean_data()

    def test_compactCRUDMixin(self):
        """Test CompactCRUD Mixin view."""
        client = self.app.test_client()
        self.browser_login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD)
        self.insert_data2()
        rv = client.get("/model1compactview/list/")
        eq_(rv.status_code, 200)
        self.clean_data()
with open(self.manifest_file, "r") as f: # the manifest includes non-entry files we only need entries in # templates full_manifest = json.load(f) self.manifest = full_manifest.get("entrypoints", {}) except Exception: # pylint: disable=broad-except pass def get_manifest_files(self, bundle: str, asset_type: str) -> List[str]: if self.app and self.app.debug: self.parse_manifest_json() return self.manifest.get(bundle, {}).get(asset_type, []) APP_DIR = os.path.dirname(__file__) appbuilder = AppBuilder(update_perms=False) async_query_manager = AsyncQueryManager() cache_manager = CacheManager() celery_app = celery.Celery() csrf = CSRFProtect() db = SQLA() _event_logger: Dict[str, Any] = {} encrypted_field_factory = EncryptedFieldFactory() event_logger = LocalProxy(lambda: _event_logger.get("event_logger")) feature_flag_manager = FeatureFlagManager() machine_auth_provider_factory = MachineAuthProviderFactory() manifest_processor = UIManifestProcessor(APP_DIR) migrate = Migrate() results_backend_manager = ResultsBackendManager() security_manager = LocalProxy(lambda: appbuilder.sm) talisman = Talisman()
import logging

from flask import Flask
from flask_appbuilder import AppBuilder
from flask_appbuilder import SQLA

from .sec import MySec

# Logging configuration: timestamped records, root logger at DEBUG.
logging.basicConfig(
    format="%(asctime)s:%(levelname)s:%(name)s:%(message)s"
)
logging.getLogger().setLevel(logging.DEBUG)

# Application objects; settings are loaded from the "config" module.
app = Flask(__name__)
app.config.from_object("config")
db = SQLA(app)
appbuilder = AppBuilder(
    app,
    db.session,
    security_manager_class=MySec,
)

# Only include this for SQLite constraints:
#
# from sqlalchemy.engine import Engine
# from sqlalchemy import event
#
# @event.listens_for(Engine, "connect")
# def set_sqlite_pragma(dbapi_connection, connection_record):
#     # Will force SQLite foreign key constraints
#     cursor = dbapi_connection.cursor()
#     cursor.execute("PRAGMA foreign_keys=ON")
#     cursor.close()

# Imported last, after ``app`` and ``appbuilder`` have been created.
from . import views
'SQLALCHEMY_DATABASE_URI': 'sqlite:///', 'CSRF_ENABLED': False, 'IMG_UPLOAD_URL': '/', 'IMG_UPLOAD_FOLDER': '/tmp/', 'WTF_CSRF_ENABLED': False, 'SECRET_KEY': 'bla' } app = Flask('wtforms_jsonschema2_fab_testing') app.config.update(cfg) engine = create_engine(app.config['SQLALCHEMY_DATABASE_URI']) metadata = MetaData(bind=engine) db = SQLAlchemy(app, metadata=metadata) ctx = app.app_context() ctx.push() appbuilder = AppBuilder(app, db.session) db.session.commit() class CauseOfDeathEnum(enum.Enum): bycatch = 'bycatch' stranding = 'stranding' class BaseObservation(db.Model): __tablename__ = 'observation' id = Column(Integer, primary_key=True) alive = Column(Boolean) length = Column(Numeric, nullable=False) dead_observation = relationship('DeadObservation', back_populates='base_observation',
class FlaskTestCase(unittest.TestCase): def setUp(self): from flask import Flask from flask_appbuilder import AppBuilder from flask_appbuilder.models.mongoengine.interface import MongoEngineInterface from flask_appbuilder import ModelView from flask_appbuilder.security.mongoengine.manager import SecurityManager self.app = Flask(__name__) self.basedir = os.path.abspath(os.path.dirname(__file__)) self.app.config['MONGODB_SETTINGS'] = {'DB': 'test'} self.app.config['CSRF_ENABLED'] = False self.app.config['SECRET_KEY'] = 'thisismyscretkey' self.app.config['WTF_CSRF_ENABLED'] = False self.db = MongoEngine(self.app) self.appbuilder = AppBuilder(self.app, security_manager_class=SecurityManager) class Model2View(ModelView): datamodel = MongoEngineInterface(Model2) list_columns = [ 'field_integer', 'field_float', 'field_string', 'field_method', 'group.field_string' ] edit_form_query_rel_fields = { 'group': [['field_string', FilterEqual, 'G2']] } add_form_query_rel_fields = { 'group': [['field_string', FilterEqual, 'G1']] } add_exclude_columns = ['excluded_string'] class Model22View(ModelView): datamodel = MongoEngineInterface(Model2) list_columns = [ 'field_integer', 'field_float', 'field_string', 'field_method', 'group.field_string' ] add_exclude_columns = ['excluded_string'] edit_exclude_columns = ['excluded_string'] show_exclude_columns = ['excluded_string'] class Model1View(ModelView): datamodel = MongoEngineInterface(Model1) related_views = [Model2View] list_columns = ['field_string', 'field_file'] class Model1CompactView(CompactCRUDMixin, ModelView): datamodel = MongoEngineInterface(Model1) class Model1Filtered1View(ModelView): datamodel = MongoEngineInterface(Model1) base_filters = [['field_string', FilterStartsWith, 'a']] class Model1MasterView(MasterDetailView): datamodel = MongoEngineInterface(Model1) related_views = [Model2View] class Model1Filtered2View(ModelView): datamodel = MongoEngineInterface(Model1) base_filters = [['field_integer', FilterEqual, 0]] class 
Model2GroupByChartView(GroupByChartView): datamodel = MongoEngineInterface(Model2) chart_title = 'Test Model1 Chart' definitions = [{ 'group': 'field_string', 'series': [(aggregate_sum, 'field_integer', aggregate_avg, 'field_integer', aggregate_count, 'field_integer')] }] class Model2DirectByChartView(DirectByChartView): datamodel = MongoEngineInterface(Model2) chart_title = 'Test Model1 Chart' definitions = [{ 'group': 'field_string', 'series': ['field_integer', 'field_float'] }] class Model2DirectChartView(DirectChartView): datamodel = MongoEngineInterface(Model2) chart_title = 'Test Model1 Chart' direct_columns = {'stat1': ('group', 'field_integer')} class Model1MasterView(MasterDetailView): datamodel = MongoEngineInterface(Model1) related_views = [Model2View] class Model1MasterChartView(MasterDetailView): datamodel = MongoEngineInterface(Model1) related_views = [Model2DirectByChartView] self.appbuilder.add_view(Model1View, "Model1", category='Model1') self.appbuilder.add_view(Model1CompactView, "Model1Compact", category='Model1') self.appbuilder.add_view(Model1MasterView, "Model1Master", category='Model1') self.appbuilder.add_view(Model1MasterChartView, "Model1MasterChart", category='Model1') self.appbuilder.add_view(Model1Filtered1View, "Model1Filtered1", category='Model1') self.appbuilder.add_view(Model1Filtered2View, "Model1Filtered2", category='Model1') self.appbuilder.add_view(Model2View, "Model2") self.appbuilder.add_view(Model22View, "Model22") self.appbuilder.add_view(Model2View, "Model2 Add", href='/model2view/add') self.appbuilder.add_view(Model2GroupByChartView, "Model2 Group By Chart") self.appbuilder.add_view(Model2DirectByChartView, "Model2 Direct By Chart") self.appbuilder.add_view(Model2DirectChartView, "Model2 Direct Chart") role_admin = self.appbuilder.sm.find_role('Admin') try: self.appbuilder.sm.add_user('admin', 'admin', 'user', '*****@*****.**', role_admin, 'general') except: pass def tearDown(self): self.appbuilder = None self.app = None 
self.db = None log.debug("TEAR DOWN") """ --------------------------------- TEST HELPER FUNCTIONS --------------------------------- """ def login(self, client, username, password): # Login with default admin return client.post('/login/', data=dict(username=username, password=password), follow_redirects=True) def logout(self, client): return client.get('/logout/') def insert_data(self): for x, i in zip(string.ascii_letters[:23], range(23)): model = Model1(field_string="%stest" % (x), field_integer=i) model.save() def insert_data2(self): models1 = [ Model1(field_string='G1'), Model1(field_string='G2'), Model1(field_string='G3') ] for model1 in models1: try: model1.save() for x, i in zip(string.ascii_letters[:10], range(10)): model = Model2(field_string="%stest" % (x), field_integer=random.randint(1, 10), field_float=random.uniform(0.0, 1.0), group=model1) year = random.choice(range(1900, 2012)) month = random.choice(range(1, 12)) day = random.choice(range(1, 28)) model.field_date = datetime.datetime(year, month, day) model.save() except Exception as e: print("ERROR {0}".format(str(e))) def clean_data(self): Model1.drop_collection() Model2.drop_collection() def test_fab_views(self): """ Test views creation and registration """ eq_(len(self.appbuilder.baseviews), 24) # current minimal views are 12 def test_index(self): """ Test initial access and index message """ client = self.app.test_client() # Check for Welcome Message rv = client.get('/') data = rv.data.decode('utf-8') ok_(DEFAULT_INDEX_STRING in data) def test_sec_login(self): """ Test Security Login, Logout, invalid login, invalid access """ client = self.app.test_client() # Try to List and Redirect to Login rv = client.get('/model1view/list/') eq_(rv.status_code, 302) rv = client.get('/model2view/list/') eq_(rv.status_code, 302) # Login and list with admin self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) rv = client.get('/model1view/list/') eq_(rv.status_code, 200) rv = 
client.get('/model2view/list/') eq_(rv.status_code, 200) # Logout and and try to list self.logout(client) rv = client.get('/model1view/list/') eq_(rv.status_code, 302) rv = client.get('/model2view/list/') eq_(rv.status_code, 302) # Invalid Login rv = self.login(client, DEFAULT_ADMIN_USER, 'password') data = rv.data.decode('utf-8') ok_(INVALID_LOGIN_STRING in data) def test_sec_reset_password(self): """ Test Security reset password """ from flask_appbuilder.security.mongoengine.models import User client = self.app.test_client() # Try Reset My password user = User.objects.filter(**{'username': '******'})[0] rv = client.get('/users/action/resetmypassword/{0}'.format(user.id), follow_redirects=True) data = rv.data.decode('utf-8') ok_(ACCESS_IS_DENIED in data) #Reset My password rv = self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) rv = client.get('/users/action/resetmypassword/{0}'.format(user.id), follow_redirects=True) data = rv.data.decode('utf-8') ok_("Reset Password Form" in data) rv = client.post('/resetmypassword/form', data=dict(password='******', conf_password='******'), follow_redirects=True) eq_(rv.status_code, 200) self.logout(client) self.login(client, DEFAULT_ADMIN_USER, 'password') rv = client.post('/resetmypassword/form', data=dict(password=DEFAULT_ADMIN_PASSWORD, conf_password=DEFAULT_ADMIN_PASSWORD), follow_redirects=True) eq_(rv.status_code, 200) #Reset Password Admin rv = client.get('/users/action/resetpasswords/{0}'.format(user.id), follow_redirects=True) data = rv.data.decode('utf-8') ok_("Reset Password Form" in data) rv = client.post('/resetmypassword/form', data=dict(password=DEFAULT_ADMIN_PASSWORD, conf_password=DEFAULT_ADMIN_PASSWORD), follow_redirects=True) eq_(rv.status_code, 200) def test_generic_interface(self): """ Test Generic Interface for generic-alter datasource """ client = self.app.test_client() rv = self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) rv = client.get('/psview/list') data = 
rv.data.decode('utf-8') def test_model_crud(self): """ Test Model add, delete, edit """ client = self.app.test_client() rv = self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) rv = client.post('/model1view/add', data=dict(field_string='test1', field_integer='1', field_float='0.12', field_date='2014-01-01 23:10:07'), follow_redirects=True) eq_(rv.status_code, 200) model = Model1.objects[0] eq_(model.field_string, u'test1') eq_(model.field_integer, 1) model1 = Model1.objects(field_string='test1')[0] rv = client.post('/model1view/edit/{0}'.format(model1.id), data=dict(field_string='test2', field_integer='2'), follow_redirects=True) eq_(rv.status_code, 200) model = Model1.objects[0] eq_(model.field_string, u'test2') eq_(model.field_integer, 2) rv = client.get('/model1view/delete/{0}'.format(model.id), follow_redirects=True) eq_(rv.status_code, 200) model = Model1.objects eq_(len(model), 0) self.clean_data() def test_excluded_cols(self): """ Test add_exclude_columns, edit_exclude_columns, show_exclude_columns """ client = self.app.test_client() rv = self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) rv = client.get('/model22view/add') eq_(rv.status_code, 200) data = rv.data.decode('utf-8') ok_('field_string' in data) ok_('field_integer' in data) ok_('field_float' in data) ok_('field_date' in data) ok_('excluded_string' not in data) self.insert_data2() model2 = Model2.objects[0] rv = client.get('/model22view/edit/{0}'.format(model2.id)) eq_(rv.status_code, 200) data = rv.data.decode('utf-8') ok_('field_string' in data) ok_('field_integer' in data) ok_('field_float' in data) ok_('field_date' in data) ok_('excluded_string' not in data) rv = client.get('/model22view/show/{0}'.format(model2.id)) eq_(rv.status_code, 200) data = rv.data.decode('utf-8') ok_('Field String' in data) ok_('Field Integer' in data) ok_('Field Float' in data) ok_('Field Date' in data) ok_('Excluded String' not in data) self.clean_data() def test_query_rel_fields(self): """ Test 
add and edit form related fields filter """ client = self.app.test_client() rv = self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) self.insert_data2() # Base filter string starts with rv = client.get('/model2view/add') data = rv.data.decode('utf-8') ok_('G1' in data) ok_('G2' not in data) model2 = Model2.objects[0] # Base filter string starts with rv = client.get('/model2view/edit/{0}'.format(model2.id)) data = rv.data.decode('utf-8') ok_('G2' in data) ok_('G1' not in data) self.clean_data() def test_model_list_order(self): """ Test Model order on lists """ self.insert_data() client = self.app.test_client() self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) rv = client.post( '/model1view/list?_oc_Model1View=field_string&_od_Model1View=asc', follow_redirects=True) eq_(rv.status_code, 200) data = rv.data.decode('utf-8') # TODO # VALIDATE LIST IS ORDERED rv = client.post( '/model1view/list?_oc_Model1View=field_string&_od_Model1View=desc', follow_redirects=True) eq_(rv.status_code, 200) data = rv.data.decode('utf-8') # TODO # VALIDATE LIST IS ORDERED self.clean_data() def test_model_add_validation(self): """ Test Model add validations """ client = self.app.test_client() self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) rv = client.post('/model1view/add', data=dict(field_string='test1', field_integer='1'), follow_redirects=True) eq_(rv.status_code, 200) rv = client.post('/model1view/add', data=dict(field_string='test1', field_integer='2'), follow_redirects=True) eq_(rv.status_code, 200) data = rv.data.decode('utf-8') ok_(UNIQUE_VALIDATION_STRING in data) model = Model1.objects() eq_(len(model), 1) rv = client.post('/model1view/add', data=dict(field_string='', field_integer='1'), follow_redirects=True) eq_(rv.status_code, 200) data = rv.data.decode('utf-8') ok_(NOTNULL_VALIDATION_STRING in data) model = Model1.objects() eq_(len(model), 1) self.clean_data() def test_model_edit_validation(self): """ Test Model edit validations """ 
client = self.app.test_client() self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) client.post('/model1view/add', data=dict(field_string='test1', field_integer='1'), follow_redirects=True) model1 = Model1.objects(field_string='test1')[0] client.post('/model1view/add', data=dict(field_string='test2', field_integer='1'), follow_redirects=True) rv = client.post('/model1view/edit/{0}'.format(model1.id), data=dict(field_string='test2', field_integer='2'), follow_redirects=True) eq_(rv.status_code, 200) data = rv.data.decode('utf-8') ok_(UNIQUE_VALIDATION_STRING in data) rv = client.post('/model1view/edit/{0}'.format(model1.id), data=dict(field_string='', field_integer='2'), follow_redirects=True) eq_(rv.status_code, 200) data = rv.data.decode('utf-8') ok_(NOTNULL_VALIDATION_STRING in data) self.clean_data() def test_model_base_filter(self): """ Test Model base filtered views """ client = self.app.test_client() self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) self.insert_data() models = Model1.objects() eq_(len(models), 23) # Base filter string starts with rv = client.get('/model1filtered1view/list/') data = rv.data.decode('utf-8') ok_('atest' in data) ok_('btest' not in data) # Base filter integer equals rv = client.get('/model1filtered2view/list/') data = rv.data.decode('utf-8') ok_('atest' in data) ok_('btest' not in data) self.clean_data() def test_model_list_method_field(self): """ Tests a model's field has a method """ client = self.app.test_client() self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) self.insert_data2() rv = client.get('/model2view/list/') eq_(rv.status_code, 200) data = rv.data.decode('utf-8') ok_('field_method_value' in data) self.clean_data() def test_compactCRUDMixin(self): """ Test CompactCRUD Mixin view """ client = self.app.test_client() self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) self.insert_data2() rv = client.get('/model1compactview/list/') eq_(rv.status_code, 200) self.clean_data() 
def test_charts_view(self): """ Test Various Chart views """ client = self.app.test_client() self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) self.insert_data2() log.info("CHART TEST") rv = client.get('/model2groupbychartview/chart/') eq_(rv.status_code, 200) rv = client.get('/model2directbychartview/chart/') eq_(rv.status_code, 200) rv = client.get('/model2directchartview/chart/') #eq_(rv.status_code, 200) self.clean_data() """
import logging

from flask import Flask
from flask_appbuilder import AppBuilder
from flask_appbuilder.security.mongoengine.manager import SecurityManager
from flask_mongoengine import MongoEngine

from app import mysecurity
from .mysecurity import MySecurityManager

# Root logger: timestamp, level, logger name and message; DEBUG and above.
logging.basicConfig(format="%(asctime)s:%(levelname)s:%(name)s:%(message)s")
logging.getLogger().setLevel(logging.DEBUG)

# Flask application, configured from the importable name "config".
app = Flask(__name__)
app.config.from_object("config")

# MongoEngine binding for the app.
dbmongo = MongoEngine(app)

# No SQLAlchemy session is passed; security is handled by MySecurityManager.
appbuilder = AppBuilder(app, security_manager_class=MySecurityManager)

# Imported last, once `app` and `appbuilder` exist.
# NOTE(review): presumably models/views import those names from this package,
# hence the late import -- confirm before moving it.
from app import models, views
import os
import logging

from flask import Flask
from flask_appbuilder import SQLA, AppBuilder

from app.index import MyIndexView

# Root logger: timestamp, level, logger name and message; DEBUG and above.
logging.basicConfig(format='%(asctime)s:%(levelname)s:%(name)s:%(message)s')
logging.getLogger().setLevel(logging.DEBUG)

# Flask application, configured from the importable name 'config'.
app = Flask(__name__)
app.config.from_object('config')

# SQLAlchemy binding; its session backs the AppBuilder instance.
db = SQLA(app)

# MyIndexView is passed as the index view instead of the default one.
appbuilder = AppBuilder(app, db.session, indexview=MyIndexView)

# Imported last, once `app`, `db` and `appbuilder` exist.
# NOTE(review): presumably `views` imports those names from this package,
# hence the late import -- confirm before moving it.
from app import views
class APITestCase(FABTestCase): def setUp(self): from flask import Flask from flask_appbuilder import AppBuilder from flask_appbuilder.models.sqla.interface import SQLAInterface from flask_appbuilder import ModelRestApi self.app = Flask(__name__) self.basedir = os.path.abspath(os.path.dirname(__file__)) self.app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///" self.app.config["SECRET_KEY"] = "thisismyscretkey" self.app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False self.app.config["FAB_API_MAX_PAGE_SIZE"] = MAX_PAGE_SIZE self.app.config["WTF_CSRF_ENABLED"] = False self.app.config["FAB_ROLES"] = { "ReadOnly": [[".*", "can_get"], [".*", "can_info"]] } self.db = SQLA(self.app) self.appbuilder = AppBuilder(self.app, self.db.session) # Create models and insert data insert_data(self.db.session, MODEL1_DATA_SIZE) class Model1Api(ModelRestApi): datamodel = SQLAInterface(Model1) list_columns = [ "field_integer", "field_float", "field_string", "field_date", ] description_columns = { "field_integer": "Field Integer", "field_float": "Field Float", "field_string": "Field String", } self.model1api = Model1Api self.appbuilder.add_api(Model1Api) class Model1ApiFieldsInfo(Model1Api): datamodel = SQLAInterface(Model1) add_columns = [ "field_integer", "field_float", "field_string", "field_date" ] edit_columns = ["field_string", "field_integer"] self.model1apifieldsinfo = Model1ApiFieldsInfo self.appbuilder.add_api(Model1ApiFieldsInfo) class Model1FuncApi(ModelRestApi): datamodel = SQLAInterface(Model1) list_columns = [ "field_integer", "field_float", "field_string", "field_date", "full_concat", ] description_columns = { "field_integer": "Field Integer", "field_float": "Field Float", "field_string": "Field String", } self.model1funcapi = Model1Api self.appbuilder.add_api(Model1FuncApi) class Model1ApiExcludeCols(ModelRestApi): datamodel = SQLAInterface(Model1) list_exclude_columns = [ "field_integer", "field_float", "field_date" ] show_exclude_columns = list_exclude_columns 
edit_exclude_columns = list_exclude_columns add_exclude_columns = list_exclude_columns self.appbuilder.add_api(Model1ApiExcludeCols) class Model1ApiOrder(ModelRestApi): datamodel = SQLAInterface(Model1) base_order = ("field_integer", "desc") self.appbuilder.add_api(Model1ApiOrder) class Model1ApiRestrictedPermissions(ModelRestApi): datamodel = SQLAInterface(Model1) base_permissions = ["can_get", "can_info"] self.appbuilder.add_api(Model1ApiRestrictedPermissions) class Model1ApiFiltered(ModelRestApi): datamodel = SQLAInterface(Model1) base_filters = [ ["field_integer", FilterGreater, 2], ["field_integer", FilterSmaller, 4], ] self.appbuilder.add_api(Model1ApiFiltered) class ModelWithEnumsApi(ModelRestApi): datamodel = SQLAInterface(ModelWithEnums) self.appbuilder.add_api(ModelWithEnumsApi) class Model1BrowserLogin(ModelRestApi): datamodel = SQLAInterface(Model1) allow_browser_login = True self.appbuilder.add_api(Model1BrowserLogin) class ModelMMApi(ModelRestApi): datamodel = SQLAInterface(ModelMMParent) self.appbuilder.add_api(ModelMMApi) class Model1CustomValidationApi(ModelRestApi): datamodel = SQLAInterface(Model1) validators_columns = {"field_string": validate_name} self.appbuilder.add_api(Model1CustomValidationApi) class Model2Api(ModelRestApi): datamodel = SQLAInterface(Model2) list_columns = ["group"] show_columns = ["group"] self.model2api = Model2Api self.appbuilder.add_api(Model2Api) class Model2ApiFilteredRelFields(ModelRestApi): datamodel = SQLAInterface(Model2) list_columns = ["group"] show_columns = ["group"] add_query_rel_fields = { "group": [ ["field_integer", FilterGreater, 2], ["field_integer", FilterSmaller, 4], ] } edit_query_rel_fields = add_query_rel_fields self.model2apifilteredrelfields = Model2ApiFilteredRelFields self.appbuilder.add_api(Model2ApiFilteredRelFields) class Model1PermOverride(ModelRestApi): datamodel = SQLAInterface(Model1) class_permission_name = 'api' method_permission_name = { "get_list": "access", "get": "access", "put": 
"access", "post": "access", "delete": "access", "info": "access" } self.model1permoverride = Model1PermOverride self.appbuilder.add_api(Model1PermOverride) self.create_admin_user(self.appbuilder, USERNAME, PASSWORD) self.create_user(self.appbuilder, USERNAME_READONLY, PASSWORD_READONLY, "ReadOnly", first_name="readonly", last_name="readonly", email="*****@*****.**") def tearDown(self): self.appbuilder = None self.app = None self.db = None def test_auth_login(self): """ REST Api: Test auth login """ client = self.app.test_client() rv = self._login(client, USERNAME, PASSWORD) eq_(rv.status_code, 200) assert json.loads(rv.data.decode("utf-8")).get( API_SECURITY_ACCESS_TOKEN_KEY, False) def test_auth_login_failed(self): """ REST Api: Test auth login failed """ client = self.app.test_client() rv = self._login(client, "fail", "fail") eq_(json.loads(rv.data), {"message": "Not authorized"}) eq_(rv.status_code, 401) def test_auth_login_bad(self): """ REST Api: Test auth login bad request """ client = self.app.test_client() rv = client.post("api/v1/security/login", data="BADADATA") eq_(rv.status_code, 400) def test_auth_authorization_browser(self): """ REST Api: Test auth with browser login """ client = self.app.test_client() rv = self.browser_login(client, USERNAME, PASSWORD) # Test access with browser login uri = "api/v1/model1browserlogin/1" rv = client.get(uri) eq_(rv.status_code, 200) # Test unauthorized access with browser login uri = "api/v1/model1api/1" rv = client.get(uri) eq_(rv.status_code, 401) # Test access wihout cookie or JWT rv = self.browser_logout(client) # Test access with browser login uri = "api/v1/model1browserlogin/1" rv = client.get(uri) eq_(rv.status_code, 401) # Test access with JWT but without cookie token = self.login(client, USERNAME, PASSWORD) uri = "api/v1/model1browserlogin/1" rv = self.auth_client_get(client, token, uri) eq_(rv.status_code, 200) def test_auth_authorization(self): """ REST Api: Test auth base limited authorization """ client = 
self.app.test_client() token = self.login(client, USERNAME, PASSWORD) # Test unauthorized DELETE pk = 1 uri = "api/v1/model1apirestrictedpermissions/{}".format(pk) rv = self.auth_client_delete(client, token, uri) eq_(rv.status_code, 401) # Test unauthorized POST item = dict( field_string="test{}".format(MODEL1_DATA_SIZE + 1), field_integer=MODEL1_DATA_SIZE + 1, field_float=float(MODEL1_DATA_SIZE + 1), field_date=None, ) uri = "api/v1/model1apirestrictedpermissions/" rv = self.auth_client_post(client, token, uri, item) eq_(rv.status_code, 401) # Test authorized GET uri = "api/v1/model1apirestrictedpermissions/1" rv = self.auth_client_get(client, token, uri) eq_(rv.status_code, 200) def test_get_item(self): """ REST Api: Test get item """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) for i in range(1, MODEL1_DATA_SIZE): rv = self.auth_client_get(client, token, "api/v1/model1api/{}".format(i)) data = json.loads(rv.data.decode("utf-8")) eq_(rv.status_code, 200) self.assert_get_item(rv, data, i - 1) def assert_get_item(self, rv, data, value): eq_( data[API_RESULT_RES_KEY], { "field_date": None, "field_float": float(value), "field_integer": value, "field_string": "test{}".format(value), }, ) # test descriptions eq_(data["description_columns"], self.model1api.description_columns) # test labels eq_( data[API_LABEL_COLUMNS_RES_KEY], { "field_date": "Field Date", "field_float": "Field Float", "field_integer": "Field Integer", "field_string": "Field String", }, ) eq_(rv.status_code, 200) def test_get_item_select_cols(self): """ REST Api: Test get item with select columns """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) for i in range(1, MODEL1_DATA_SIZE): uri = "api/v1/model1api/{}?q=({}:!(field_integer))".format( i, API_SELECT_COLUMNS_RIS_KEY) rv = self.auth_client_get(client, token, uri) data = json.loads(rv.data.decode("utf-8")) eq_(data[API_RESULT_RES_KEY], {"field_integer": i - 1}) eq_( 
data[API_DESCRIPTION_COLUMNS_RES_KEY], {"field_integer": "Field Integer"}, ) eq_(data[API_LABEL_COLUMNS_RES_KEY], {"field_integer": "Field Integer"}) eq_(rv.status_code, 200) def test_get_item_select_meta_data(self): """ REST Api: Test get item select meta data """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) selectable_keys = [ API_DESCRIPTION_COLUMNS_RIS_KEY, API_LABEL_COLUMNS_RIS_KEY, API_SHOW_COLUMNS_RIS_KEY, API_SHOW_TITLE_RIS_KEY, ] for selectable_key in selectable_keys: argument = {API_SELECT_KEYS_RIS_KEY: [selectable_key]} uri = "api/v1/model1api/1?{}={}".format(API_URI_RIS_KEY, prison.dumps(argument)) rv = self.auth_client_get(client, token, uri) data = json.loads(rv.data.decode("utf-8")) eq_(len(data.keys()), 1 + 2) # always exist id, result # We assume that rison meta key equals result meta key assert selectable_key in data def test_get_item_excluded_cols(self): """ REST Api: Test get item with excluded columns """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) pk = 1 rv = self.auth_client_get(client, token, "api/v1/model1apiexcludecols/{}".format(pk)) data = json.loads(rv.data.decode("utf-8")) eq_(data[API_RESULT_RES_KEY], {"field_string": "test0"}) eq_(rv.status_code, 200) def test_get_item_not_found(self): """ REST Api: Test get item not found """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) pk = MODEL1_DATA_SIZE + 1 rv = self.auth_client_get(client, token, "api/v1/model1api/{}".format(pk)) eq_(rv.status_code, 404) def test_get_item_base_filters(self): """ REST Api: Test get item with base filters """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) # We can't get a base filtered item pk = 1 rv = self.auth_client_get(client, token, "api/v1/model1apifiltered/{}".format(pk)) eq_(rv.status_code, 404) # This one is ok pk=4 field_integer=3 2>3<4 pk = 4 rv = self.auth_client_get(client, token, 
"api/v1/model1apifiltered/{}".format(pk)) eq_(rv.status_code, 200) def test_get_item_1m_field(self): """ REST Api: Test get item with 1-N related field """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) # We can't get a base filtered item pk = 1 rv = self.auth_client_get(client, token, "api/v1/model2api/{}".format(pk)) data = json.loads(rv.data.decode("utf-8")) eq_(rv.status_code, 200) expected_rel_field = { "group": { "field_date": None, "field_float": 0.0, "field_integer": 0, "field_string": "test0", "id": 1, } } eq_(data[API_RESULT_RES_KEY], expected_rel_field) def test_get_item_mm_field(self): """ REST Api: Test get item with N-N related field """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) # We can't get a base filtered item pk = 1 rv = self.auth_client_get(client, token, "api/v1/modelmmapi/{}".format(pk)) data = json.loads(rv.data.decode("utf-8")) eq_(rv.status_code, 200) expected_rel_field = [ { "field_string": "1", "id": 1 }, { "field_string": "2", "id": 2 }, { "field_string": "3", "id": 3 }, ] eq_(data[API_RESULT_RES_KEY]["children"], expected_rel_field) def test_get_list(self): """ REST Api: Test get list """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) rv = self.auth_client_get(client, token, "api/v1/model1api/") data = json.loads(rv.data.decode("utf-8")) # Tests count property eq_(data["count"], MODEL1_DATA_SIZE) # Tests data result default page size eq_(len(data[API_RESULT_RES_KEY]), self.model1api.page_size) for i in range(1, self.model1api.page_size): self.assert_get_list(rv, data[API_RESULT_RES_KEY][i - 1], i - 1) @staticmethod def assert_get_list(rv, data, value): eq_( data, { "field_date": None, "field_float": float(value), "field_integer": value, "field_string": "test{}".format(value), }, ) eq_(rv.status_code, 200) def test_get_list_order(self): """ REST Api: Test get list order params """ client = self.app.test_client() token = self.login(client, 
USERNAME, PASSWORD) # test string order asc arguments = {"order_column": "field_integer", "order_direction": "asc"} uri = "api/v1/model1api/?{}={}".format(API_URI_RIS_KEY, prison.dumps(arguments)) rv = self.auth_client_get(client, token, uri) data = json.loads(rv.data.decode("utf-8")) eq_( data[API_RESULT_RES_KEY][0], { "field_date": None, "field_float": 0.0, "field_integer": 0, "field_string": "test0", }, ) eq_(rv.status_code, 200) # test string order desc arguments = { "order_column": "field_integer", "order_direction": "desc" } uri = "api/v1/model1api/?{}={}".format(API_URI_RIS_KEY, prison.dumps(arguments)) rv = self.auth_client_get(client, token, uri) data = json.loads(rv.data.decode("utf-8")) eq_( data[API_RESULT_RES_KEY][0], { "field_date": None, "field_float": float(MODEL1_DATA_SIZE - 1), "field_integer": MODEL1_DATA_SIZE - 1, "field_string": "test{}".format(MODEL1_DATA_SIZE - 1), }, ) eq_(rv.status_code, 200) def test_get_list_base_order(self): """ REST Api: Test get list with base order """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) # test string order asc rv = self.auth_client_get(client, token, "api/v1/model1apiorder/") data = json.loads(rv.data.decode("utf-8")) eq_( data[API_RESULT_RES_KEY][0], { "field_date": None, "field_float": float(MODEL1_DATA_SIZE - 1), "field_integer": MODEL1_DATA_SIZE - 1, "field_string": "test{}".format(MODEL1_DATA_SIZE - 1), }, ) # Test override arguments = {"order_column": "field_integer", "order_direction": "asc"} uri = "api/v1/model1apiorder/?{}={}".format(API_URI_RIS_KEY, prison.dumps(arguments)) rv = self.auth_client_get(client, token, uri) data = json.loads(rv.data.decode("utf-8")) eq_( data[API_RESULT_RES_KEY][0], { "field_date": None, "field_float": 0.0, "field_integer": 0, "field_string": "test0", }, ) def test_get_list_page(self): """ REST Api: Test get list page params """ page_size = 5 client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) # test page zero 
arguments = { "page_size": page_size, "page": 0, "order_column": "field_integer", "order_direction": "asc", } uri = "api/v1/model1api/?{}={}".format(API_URI_RIS_KEY, prison.dumps(arguments)) rv = self.auth_client_get(client, token, uri) data = json.loads(rv.data.decode("utf-8")) eq_( data[API_RESULT_RES_KEY][0], { "field_date": None, "field_float": 0.0, "field_integer": 0, "field_string": "test0", }, ) eq_(rv.status_code, 200) eq_(len(data[API_RESULT_RES_KEY]), page_size) # test page one arguments = { "page_size": page_size, "page": 1, "order_column": "field_integer", "order_direction": "asc", } uri = "api/v1/model1api/?{}={}".format(API_URI_RIS_KEY, prison.dumps(arguments)) rv = self.auth_client_get(client, token, uri) data = json.loads(rv.data.decode("utf-8")) eq_( data[API_RESULT_RES_KEY][0], { "field_date": None, "field_float": float(page_size), "field_integer": page_size, "field_string": "test{}".format(page_size), }, ) eq_(rv.status_code, 200) eq_(len(data[API_RESULT_RES_KEY]), page_size) def test_get_list_max_page_size(self): """ REST Api: Test get list max page size config setting """ page_size = 100 # Max is globally set to MAX_PAGE_SIZE client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) # test page zero arguments = { "page_size": page_size, "page": 0, "order_column": "field_integer", "order_direction": "asc", } uri = "api/v1/model1api/?{}={}".format(API_URI_RIS_KEY, prison.dumps(arguments)) print("URI {}".format(uri)) rv = self.auth_client_get(client, token, uri) data = json.loads(rv.data.decode("utf-8")) eq_(len(data[API_RESULT_RES_KEY]), MAX_PAGE_SIZE) def test_get_list_filters(self): """ REST Api: Test get list filter params """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) filter_value = 5 # test string order asc arguments = { API_FILTERS_RIS_KEY: [{ "col": "field_integer", "opr": "gt", "value": filter_value }], "order_column": "field_integer", "order_direction": "asc", } uri = 
"api/v1/model1api/?{}={}".format(API_URI_RIS_KEY, prison.dumps(arguments)) rv = self.auth_client_get(client, token, uri) data = json.loads(rv.data.decode("utf-8")) eq_( data[API_RESULT_RES_KEY][0], { "field_date": None, "field_float": float(filter_value + 1), "field_integer": filter_value + 1, "field_string": "test{}".format(filter_value + 1), }, ) eq_(rv.status_code, 200) def test_get_list_select_cols(self): """ REST Api: Test get list with selected columns """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) argument = { API_SELECT_COLUMNS_RIS_KEY: ["field_integer"], "order_column": "field_integer", "order_direction": "asc", } uri = "api/v1/model1api/?{}={}".format(API_URI_RIS_KEY, prison.dumps(argument)) rv = self.auth_client_get(client, token, uri) data = json.loads(rv.data.decode("utf-8")) eq_(data[API_RESULT_RES_KEY][0], {"field_integer": 0}) eq_(data[API_LABEL_COLUMNS_RES_KEY], {"field_integer": "Field Integer"}) eq_(data[API_DESCRIPTION_COLUMNS_RES_KEY], {"field_integer": "Field Integer"}) eq_(data[API_LIST_COLUMNS_RES_KEY], ["field_integer"]) eq_(rv.status_code, 200) def test_get_list_select_meta_data(self): """ REST Api: Test get list select meta data """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) selectable_keys = [ API_DESCRIPTION_COLUMNS_RIS_KEY, API_LABEL_COLUMNS_RIS_KEY, API_ORDER_COLUMNS_RIS_KEY, API_LIST_COLUMNS_RIS_KEY, API_LIST_TITLE_RIS_KEY, ] for selectable_key in selectable_keys: argument = {API_SELECT_KEYS_RIS_KEY: [selectable_key]} uri = "api/v1/model1api/?{}={}".format(API_URI_RIS_KEY, prison.dumps(argument)) rv = self.auth_client_get(client, token, uri) data = json.loads(rv.data.decode("utf-8")) eq_(len(data.keys()), 1 + 3) # always exist count, ids, result # We assume that rison meta key equals result meta key assert selectable_key in data def test_get_list_exclude_cols(self): """ REST Api: Test get list with excluded columns """ client = self.app.test_client() token = 
self.login(client, USERNAME, PASSWORD) uri = "api/v1/model1apiexcludecols/" rv = self.auth_client_get(client, token, uri) data = json.loads(rv.data.decode("utf-8")) eq_(data[API_RESULT_RES_KEY][0], {"field_string": "test0"}) def test_get_list_base_filters(self): """ REST Api: Test get list with base filters """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) arguments = { "order_column": "field_integer", "order_direction": "desc" } uri = "api/v1/model1apifiltered/?{}={}".format(API_URI_RIS_KEY, prison.dumps(arguments)) rv = self.auth_client_get(client, token, uri) data = json.loads(rv.data.decode("utf-8")) expected_result = [{ "field_date": None, "field_float": 3.0, "field_integer": 3, "field_string": "test3", }] eq_(data[API_RESULT_RES_KEY], expected_result) def test_info_filters(self): """ REST Api: Test info filters """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) uri = "api/v1/model1api/_info" rv = self.auth_client_get(client, token, uri) data = json.loads(rv.data.decode("utf-8")) expected_filters = { "field_date": [ { "name": "Equal to", "operator": "eq" }, { "name": "Greater than", "operator": "gt" }, { "name": "Smaller than", "operator": "lt" }, { "name": "Not Equal to", "operator": "neq" }, ], "field_float": [ { "name": "Equal to", "operator": "eq" }, { "name": "Greater than", "operator": "gt" }, { "name": "Smaller than", "operator": "lt" }, { "name": "Not Equal to", "operator": "neq" }, ], "field_integer": [ { "name": "Equal to", "operator": "eq" }, { "name": "Greater than", "operator": "gt" }, { "name": "Smaller than", "operator": "lt" }, { "name": "Not Equal to", "operator": "neq" }, ], "field_string": [ { "name": "Starts with", "operator": "sw" }, { "name": "Ends with", "operator": "ew" }, { "name": "Contains", "operator": "ct" }, { "name": "Equal to", "operator": "eq" }, { "name": "Not Starts with", "operator": "nsw" }, { "name": "Not Ends with", "operator": "new" }, { "name": "Not 
Contains", "operator": "nct" }, { "name": "Not Equal to", "operator": "neq" }, ], } eq_(data["filters"], expected_filters) def test_info_fields(self): """ REST Api: Test info fields (add, edit) """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) uri = "api/v1/model1apifieldsinfo/_info" rv = self.auth_client_get(client, token, uri) data = json.loads(rv.data.decode("utf-8")) expect_add_fields = [ { "description": "Field Integer", "label": "Field Integer", "name": "field_integer", "required": False, "unique": False, "type": "Integer", }, { "description": "Field Float", "label": "Field Float", "name": "field_float", "required": False, "unique": False, "type": "Float", }, { "description": "Field String", "label": "Field String", "name": "field_string", "required": True, "unique": True, "type": "String", "validate": ["<Length(min=None, max=50, equal=None, error=None)>"], }, { "description": "", "label": "Field Date", "name": "field_date", "required": False, "unique": False, "type": "Date", }, ] expect_edit_fields = list() for edit_col in self.model1apifieldsinfo.edit_columns: for item in expect_add_fields: if item["name"] == edit_col: expect_edit_fields.append(item) eq_(data[API_ADD_COLUMNS_RES_KEY], expect_add_fields) eq_(data[API_EDIT_COLUMNS_RES_KEY], expect_edit_fields) def test_info_fields_rel_field(self): """ REST Api: Test info fields with related fields """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) uri = "api/v1/model2api/_info" rv = self.auth_client_get(client, token, uri) data = json.loads(rv.data.decode("utf-8")) expected_rel_add_field = { "count": MODEL2_DATA_SIZE, "description": "", "label": "Group", "name": "group", "required": True, "unique": False, "type": "Related", "values": [], } for i in range(self.model2api.page_size): expected_rel_add_field["values"].append({ "id": i + 1, "value": "test{}".format(i) }) for rel_field in data[API_ADD_COLUMNS_RES_KEY]: if rel_field["name"] == "group": 
eq_(rel_field, expected_rel_add_field) def test_info_fields_rel_filtered_field(self): """ REST Api: Test info fields with filtered related fields """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) uri = "api/v1/model2apifilteredrelfields/_info" rv = self.auth_client_get(client, token, uri) data = json.loads(rv.data.decode("utf-8")) expected_rel_add_field = { "description": "", "label": "Group", "name": "group", "required": True, "unique": False, "type": "Related", "count": 1, "values": [{ "id": 4, "value": "test3" }], } for rel_field in data[API_ADD_COLUMNS_RES_KEY]: if rel_field["name"] == "group": eq_(rel_field, expected_rel_add_field) for rel_field in data[API_EDIT_COLUMNS_RES_KEY]: if rel_field["name"] == "group": eq_(rel_field, expected_rel_add_field) def test_info_permissions(self): """ REST Api: Test info permissions """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) uri = "api/v1/model1api/_info" rv = self.auth_client_get(client, token, uri) data = json.loads(rv.data.decode("utf-8")) expected_permissions = [ "can_delete", "can_get", "can_info", "can_post", "can_put", ] eq_(sorted(data[API_PERMISSIONS_RES_KEY]), expected_permissions) uri = "api/v1/model1apirestrictedpermissions/_info" rv = self.auth_client_get(client, token, uri) data = json.loads(rv.data.decode("utf-8")) expected_permissions = ["can_get", "can_info"] eq_(sorted(data[API_PERMISSIONS_RES_KEY]), expected_permissions) def test_info_select_meta_data(self): """ REST Api: Test info select meta data """ # select meta for add fields client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) selectable_keys = [ API_ADD_COLUMNS_RIS_KEY, API_EDIT_COLUMNS_RIS_KEY, API_PERMISSIONS_RIS_KEY, API_FILTERS_RIS_KEY, API_ADD_TITLE_RIS_KEY, API_EDIT_TITLE_RIS_KEY, ] for selectable_key in selectable_keys: arguments = {API_SELECT_KEYS_RIS_KEY: [selectable_key]} uri = "api/v1/model1api/_info?{}={}".format( API_URI_RIS_KEY, 
prison.dumps(arguments)) rv = self.auth_client_get(client, token, uri) data = json.loads(rv.data.decode("utf-8")) eq_(len(data.keys()), 1) # We assume that rison meta key equals result meta key assert selectable_key in data def test_delete_item(self): """ REST Api: Test delete item """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) pk = 2 uri = "api/v1/model1api/{}".format(pk) rv = self.auth_client_delete(client, token, uri) eq_(rv.status_code, 200) model = self.db.session.query(Model1).get(pk) eq_(model, None) def test_delete_item_not_found(self): """ REST Api: Test delete item not found """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) pk = MODEL1_DATA_SIZE + 1 uri = "api/v1/model1api/{}".format(pk) rv = self.auth_client_delete(client, token, uri) eq_(rv.status_code, 404) def test_delete_item_base_filters(self): """ REST Api: Test delete item with base filters """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) # Try to delete a filtered item pk = 1 uri = "api/v1/model1apifiltered/{}".format(pk) rv = self.auth_client_delete(client, token, uri) eq_(rv.status_code, 404) def test_update_item(self): """ REST Api: Test update item """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) pk = 3 item = dict(field_string="test_Put", field_integer=0, field_float=0.0) uri = "api/v1/model1api/{}".format(pk) rv = self.auth_client_put(client, token, uri, item) eq_(rv.status_code, 200) model = self.db.session.query(Model1).get(pk) eq_(model.field_string, "test_Put") eq_(model.field_integer, 0) eq_(model.field_float, 0.0) def test_update_custom_validation(self): """ REST Api: Test update item custom validation """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) pk = 3 item = dict(field_string="test_Put", field_integer=0, field_float=0.0) uri = "api/v1/model1customvalidationapi/{}".format(pk) rv = self.auth_client_put(client, 
token, uri, item) eq_(rv.status_code, 422) pk = 3 item = dict(field_string="Atest_Put", field_integer=0, field_float=0.0) uri = "api/v1/model1customvalidationapi/{}".format(pk) rv = self.auth_client_put(client, token, uri, item) eq_(rv.status_code, 200) def test_update_item_base_filters(self): """ REST Api: Test update item with base filters """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) pk = 4 item = dict(field_string="test_Put", field_integer=3, field_float=3.0) uri = "api/v1/model1apifiltered/{}".format(pk) rv = self.auth_client_put(client, token, uri, item) eq_(rv.status_code, 200) model = self.db.session.query(Model1).get(pk) eq_(model.field_string, "test_Put") eq_(model.field_integer, 3) eq_(model.field_float, 3.0) # We can't update an item that is base filtered pk = 1 uri = "api/v1/model1apifiltered/{}".format(pk) rv = self.auth_client_put(client, token, uri, item) eq_(rv.status_code, 404) def test_update_item_not_found(self): """ REST Api: Test update item not found """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) pk = MODEL1_DATA_SIZE + 1 item = dict(field_string="test_Put", field_integer=0, field_float=0.0) uri = "api/v1/model1api/{}".format(pk) rv = self.auth_client_put(client, token, uri, item) eq_(rv.status_code, 404) def test_update_val_size(self): """ REST Api: Test update validate size """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) pk = 1 field_string = "a" * 51 item = dict(field_string=field_string, field_integer=11, field_float=11.0) uri = "api/v1/model1api/{}".format(pk) rv = self.auth_client_put(client, token, uri, item) eq_(rv.status_code, 422) data = json.loads(rv.data.decode("utf-8")) eq_(data["message"]["field_string"][0], "Longer than maximum length 50.") def test_update_mm_field(self): """ REST Api: Test update m-m field """ model = ModelMMChild() model.field_string = "update_m,m" self.appbuilder.get_session.add(model) 
self.appbuilder.get_session.commit() client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) pk = 1 item = dict(children=[4]) uri = "api/v1/modelmmapi/{}".format(pk) rv = self.auth_client_put(client, token, uri, item) eq_(rv.status_code, 200) data = json.loads(rv.data.decode("utf-8")) eq_(data[API_RESULT_RES_KEY], {"children": [4], "field_string": "0"}) def test_update_item_val_type(self): """ REST Api: Test update validate type """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) pk = 1 item = dict( field_string="test{}".format(MODEL1_DATA_SIZE + 1), field_integer="test{}".format(MODEL1_DATA_SIZE + 1), field_float=11.0, ) uri = "api/v1/model1api/{}".format(pk) rv = self.auth_client_put(client, token, uri, item) eq_(rv.status_code, 422) data = json.loads(rv.data.decode("utf-8")) eq_(data["message"]["field_integer"][0], "Not a valid integer.") item = dict(field_string=11, field_integer=11, field_float=11.0) rv = self.auth_client_put(client, token, uri, item) eq_(rv.status_code, 422) data = json.loads(rv.data.decode("utf-8")) eq_(data["message"]["field_string"][0], "Not a valid string.") def test_update_item_excluded_cols(self): """ REST Api: Test update item with excluded cols """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) pk = 1 item = dict(field_string="test_Put", field_integer=1000) uri = "api/v1/model1apiexcludecols/{}".format(pk) rv = self.auth_client_put(client, token, uri, item) eq_(rv.status_code, 200) model = self.db.session.query(Model1).get(pk) eq_(model.field_integer, 0) eq_(model.field_float, 0.0) eq_(model.field_date, None) def test_create_item(self): """ REST Api: Test create item """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) item = dict( field_string="test{}".format(MODEL1_DATA_SIZE + 1), field_integer=MODEL1_DATA_SIZE + 1, field_float=float(MODEL1_DATA_SIZE + 1), field_date=None, ) uri = "api/v1/model1api/" rv = 
self.auth_client_post(client, token, uri, item) data = json.loads(rv.data.decode("utf-8")) eq_(rv.status_code, 201) eq_(data[API_RESULT_RES_KEY], item) model = (self.db.session.query(Model1).filter_by( field_string="test{}".format(MODEL1_DATA_SIZE + 1)).first()) eq_(model.field_string, "test{}".format(MODEL1_DATA_SIZE + 1)) eq_(model.field_integer, MODEL1_DATA_SIZE + 1) eq_(model.field_float, float(MODEL1_DATA_SIZE + 1)) def test_create_item_custom_validation(self): """ REST Api: Test create item custom validation """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) item = dict( field_string="test{}".format(MODEL1_DATA_SIZE + 1), field_integer=MODEL1_DATA_SIZE + 1, field_float=float(MODEL1_DATA_SIZE + 1), field_date=None, ) uri = "api/v1/model1customvalidationapi/" rv = self.auth_client_post(client, token, uri, item) data = json.loads(rv.data.decode("utf-8")) eq_(rv.status_code, 422) eq_(data, {"message": {"field_string": ["Name must start with an A"]}}) item = dict( field_string="A{}".format(MODEL1_DATA_SIZE + 1), field_integer=MODEL1_DATA_SIZE + 1, field_float=float(MODEL1_DATA_SIZE + 1), field_date=None, ) uri = "api/v1/model1customvalidationapi/" rv = self.auth_client_post(client, token, uri, item) data = json.loads(rv.data.decode("utf-8")) eq_(rv.status_code, 201) def test_create_item_val_size(self): """ REST Api: Test create validate size """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) field_string = "a" * 51 item = dict( field_string=field_string, field_integer=MODEL1_DATA_SIZE + 1, field_float=float(MODEL1_DATA_SIZE + 1), ) uri = "api/v1/model1api/" rv = self.auth_client_post(client, token, uri, item) eq_(rv.status_code, 422) data = json.loads(rv.data.decode("utf-8")) eq_(data["message"]["field_string"][0], "Longer than maximum length 50.") def test_create_item_val_type(self): """ REST Api: Test create validate type """ # Test integer as string client = self.app.test_client() token = 
self.login(client, USERNAME, PASSWORD) item = dict( field_string="test{}".format(MODEL1_DATA_SIZE), field_integer="test{}".format(MODEL1_DATA_SIZE), field_float=float(MODEL1_DATA_SIZE), ) uri = "api/v1/model1api/" rv = self.auth_client_post(client, token, uri, item) eq_(rv.status_code, 422) data = json.loads(rv.data.decode("utf-8")) eq_(data["message"]["field_integer"][0], "Not a valid integer.") # Test string as integer item = dict( field_string=MODEL1_DATA_SIZE, field_integer=MODEL1_DATA_SIZE, field_float=float(MODEL1_DATA_SIZE), ) rv = self.auth_client_post(client, token, uri, item) eq_(rv.status_code, 422) data = json.loads(rv.data.decode("utf-8")) eq_(data["message"]["field_string"][0], "Not a valid string.") def test_create_item_excluded_cols(self): """ REST Api: Test create with excluded columns """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) item = dict(field_string="test{}".format(MODEL1_DATA_SIZE + 1)) uri = "api/v1/model1apiexcludecols/" rv = self.auth_client_post(client, token, uri, item) eq_(rv.status_code, 201) item = dict( field_string="test{}".format(MODEL1_DATA_SIZE + 2), field_integer=MODEL1_DATA_SIZE + 2, ) rv = self.auth_client_post(client, token, uri, item) eq_(rv.status_code, 201) model = (self.db.session.query(Model1).filter_by( field_string="test{}".format(MODEL1_DATA_SIZE + 1)).first()) eq_(model.field_integer, None) eq_(model.field_float, None) eq_(model.field_date, None) def test_create_item_with_enum(self): """ REST Api: Test create item with enum """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) item = dict(enum2="e1") uri = "api/v1/modelwithenumsapi/" rv = self.auth_client_post(client, token, uri, item) data = json.loads(rv.data.decode("utf-8")) eq_(rv.status_code, 201) model = self.db.session.query(ModelWithEnums).get(data["id"]) eq_(model.enum2, TmpEnum.e1) def test_get_list_col_function(self): """ REST Api: Test get list of objects with columns as functions """ 
client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) uri = "api/v1/model1funcapi/" rv = self.auth_client_get(client, token, uri) data = json.loads(rv.data.decode("utf-8")) # Tests count property eq_(data["count"], MODEL1_DATA_SIZE) # Tests data result default page size eq_(len(data[API_RESULT_RES_KEY]), self.model1api.page_size) for i in range(1, self.model1api.page_size): item = data[API_RESULT_RES_KEY][i - 1] eq_( item["full_concat"], "{}.{}.{}.{}".format("test" + str(i - 1), i - 1, float(i - 1), None), ) def test_openapi(self): """ REST Api: Test OpenAPI spec """ client = self.app.test_client() token = self.login(client, USERNAME, PASSWORD) uri = "api/v1/_openapi" rv = self.auth_client_get(client, token, uri) eq_(rv.status_code, 200)
def setUp(self):
    """
    REST Api test fixture.

    Creates a Flask app configured for the "sqlite:///" database URI,
    wraps it with SQLA and AppBuilder, inserts MODEL1_DATA_SIZE rows of
    test data, registers every ModelRestApi subclass the tests call and
    finally creates the admin user and a user with the "ReadOnly" role.

    API classes that tests need to introspect (page size, edit columns)
    are also stored on ``self``.
    """
    from flask import Flask
    from flask_appbuilder import AppBuilder
    from flask_appbuilder.models.sqla.interface import SQLAInterface
    from flask_appbuilder import ModelRestApi

    self.app = Flask(__name__)
    self.basedir = os.path.abspath(os.path.dirname(__file__))
    self.app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///"
    self.app.config["SECRET_KEY"] = "thisismyscretkey"
    self.app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
    self.app.config["FAB_API_MAX_PAGE_SIZE"] = MAX_PAGE_SIZE
    self.app.config["WTF_CSRF_ENABLED"] = False
    # Config-declared role: each entry is a [view name regex, permission] pair.
    self.app.config["FAB_ROLES"] = {
        "ReadOnly": [[".*", "can_get"], [".*", "can_info"]]
    }

    self.db = SQLA(self.app)
    self.appbuilder = AppBuilder(self.app, self.db.session)
    # Create models and insert data
    insert_data(self.db.session, MODEL1_DATA_SIZE)

    class Model1Api(ModelRestApi):
        datamodel = SQLAInterface(Model1)
        list_columns = [
            "field_integer",
            "field_float",
            "field_string",
            "field_date",
        ]
        description_columns = {
            "field_integer": "Field Integer",
            "field_float": "Field Float",
            "field_string": "Field String",
        }

    self.model1api = Model1Api
    self.appbuilder.add_api(Model1Api)

    class Model1ApiFieldsInfo(Model1Api):
        datamodel = SQLAInterface(Model1)
        add_columns = [
            "field_integer",
            "field_float",
            "field_string",
            "field_date",
        ]
        edit_columns = ["field_string", "field_integer"]

    self.model1apifieldsinfo = Model1ApiFieldsInfo
    self.appbuilder.add_api(Model1ApiFieldsInfo)

    class Model1FuncApi(ModelRestApi):
        datamodel = SQLAInterface(Model1)
        # "full_concat" is the extra, non-field entry in this list.
        list_columns = [
            "field_integer",
            "field_float",
            "field_string",
            "field_date",
            "full_concat",
        ]
        description_columns = {
            "field_integer": "Field Integer",
            "field_float": "Field Float",
            "field_string": "Field String",
        }

    # Point at the class registered on the next line, not at Model1Api.
    self.model1funcapi = Model1FuncApi
    self.appbuilder.add_api(Model1FuncApi)

    class Model1ApiExcludeCols(ModelRestApi):
        datamodel = SQLAInterface(Model1)
        list_exclude_columns = [
            "field_integer",
            "field_float",
            "field_date",
        ]
        # The same list object is shared by all four exclude settings.
        show_exclude_columns = list_exclude_columns
        edit_exclude_columns = list_exclude_columns
        add_exclude_columns = list_exclude_columns

    self.appbuilder.add_api(Model1ApiExcludeCols)

    class Model1ApiOrder(ModelRestApi):
        datamodel = SQLAInterface(Model1)
        base_order = ("field_integer", "desc")

    self.appbuilder.add_api(Model1ApiOrder)

    class Model1ApiRestrictedPermissions(ModelRestApi):
        datamodel = SQLAInterface(Model1)
        base_permissions = ["can_get", "can_info"]

    self.appbuilder.add_api(Model1ApiRestrictedPermissions)

    class Model1ApiFiltered(ModelRestApi):
        datamodel = SQLAInterface(Model1)
        # Both filters apply: field_integer greater than 2 and smaller than 4.
        base_filters = [
            ["field_integer", FilterGreater, 2],
            ["field_integer", FilterSmaller, 4],
        ]

    self.appbuilder.add_api(Model1ApiFiltered)

    class ModelWithEnumsApi(ModelRestApi):
        datamodel = SQLAInterface(ModelWithEnums)

    self.appbuilder.add_api(ModelWithEnumsApi)

    class Model1BrowserLogin(ModelRestApi):
        datamodel = SQLAInterface(Model1)
        allow_browser_login = True

    self.appbuilder.add_api(Model1BrowserLogin)

    class ModelMMApi(ModelRestApi):
        datamodel = SQLAInterface(ModelMMParent)

    self.appbuilder.add_api(ModelMMApi)

    class Model1CustomValidationApi(ModelRestApi):
        datamodel = SQLAInterface(Model1)
        validators_columns = {"field_string": validate_name}

    self.appbuilder.add_api(Model1CustomValidationApi)

    class Model2Api(ModelRestApi):
        datamodel = SQLAInterface(Model2)
        list_columns = ["group"]
        show_columns = ["group"]

    self.model2api = Model2Api
    self.appbuilder.add_api(Model2Api)

    class Model2ApiFilteredRelFields(ModelRestApi):
        datamodel = SQLAInterface(Model2)
        list_columns = ["group"]
        show_columns = ["group"]
        add_query_rel_fields = {
            "group": [
                ["field_integer", FilterGreater, 2],
                ["field_integer", FilterSmaller, 4],
            ]
        }
        # Edit forms reuse the very same filter definition as add forms.
        edit_query_rel_fields = add_query_rel_fields

    self.model2apifilteredrelfields = Model2ApiFilteredRelFields
    self.appbuilder.add_api(Model2ApiFilteredRelFields)

    class Model1PermOverride(ModelRestApi):
        datamodel = SQLAInterface(Model1)
        class_permission_name = 'api'
        # Every listed method is mapped to the single name "access".
        method_permission_name = {
            "get_list": "access",
            "get": "access",
            "put": "access",
            "post": "access",
            "delete": "access",
            "info": "access",
        }

    self.model1permoverride = Model1PermOverride
    self.appbuilder.add_api(Model1PermOverride)

    # Users are created last, after all APIs have been registered.
    self.create_admin_user(self.appbuilder, USERNAME, PASSWORD)
    self.create_user(
        self.appbuilder,
        USERNAME_READONLY,
        PASSWORD_READONLY,
        "ReadOnly",
        first_name="readonly",
        last_name="readonly",
        email="*****@*****.**",
    )
# Application bootstrap: configures logging, then creates the module-level
# `app`, `db` and `appbuilder` objects.
import logging

from flask import Flask
from flask_appbuilder import SQLA, AppBuilder

""" Logging configuration """
# Set the root logger's message format, then raise its level to DEBUG.
logging.basicConfig(format="%(asctime)s:%(levelname)s:%(name)s:%(message)s")
logging.getLogger().setLevel(logging.DEBUG)

# Settings are loaded from the object importable under the name "config".
app = Flask(__name__)
app.config.from_object("config")
db = SQLA(app)
appbuilder = AppBuilder(app, db.session)

# The bare string below is never executed. It holds the text of an optional
# SQLAlchemy "connect" listener that would issue PRAGMA foreign_keys=ON for
# SQLite connections.
""" from sqlalchemy.engine import Engine from sqlalchemy import event #Only include this for SQLLite constraints @event.listens_for(Engine, "connect") def set_sqlite_pragma(dbapi_connection, connection_record): # Will force sqllite contraint foreign keys cursor = dbapi_connection.cursor() cursor.execute("PRAGMA foreign_keys=ON") cursor.close() """

# NOTE(review): imported after `appbuilder` exists, presumably because
# app.views imports `appbuilder`/`db` from this package - confirm before
# moving this line.
from app import views
class FlaskTestCase(FABTestCase): def setUp(self): from flask import Flask from flask_appbuilder import AppBuilder from flask_appbuilder.models.sqla.interface import SQLAInterface from flask_appbuilder.views import ModelView self.app = Flask(__name__) self.app.jinja_env.undefined = jinja2.StrictUndefined self.app.config.from_object("flask_appbuilder.tests.config_api") logging.basicConfig(level=logging.ERROR) self.db = SQLA(self.app) self.appbuilder = AppBuilder(self.app, self.db.session) sess = PSSession() class PSView(ModelView): datamodel = GenericInterface(PSModel, sess) base_permissions = ["can_list", "can_show"] list_columns = ["UID", "C", "CMD", "TIME"] search_columns = ["UID", "C", "CMD"] class Model2View(ModelView): datamodel = SQLAInterface(Model2) list_columns = [ "field_integer", "field_float", "field_string", "field_method", "group.field_string", ] edit_form_query_rel_fields = { "group": [["field_string", FilterEqual, "test1"]] } add_form_query_rel_fields = { "group": [["field_string", FilterEqual, "test0"]] } class Model22View(ModelView): datamodel = SQLAInterface(Model2) list_columns = [ "field_integer", "field_float", "field_string", "field_method", "group.field_string", ] add_exclude_columns = ["excluded_string"] edit_exclude_columns = ["excluded_string"] show_exclude_columns = ["excluded_string"] class Model1View(ModelView): datamodel = SQLAInterface(Model1) related_views = [Model2View] list_columns = ["field_string", "field_file"] class Model3View(ModelView): datamodel = SQLAInterface(Model3) list_columns = ["pk1", "pk2", "field_string"] add_columns = ["pk1", "pk2", "field_string"] edit_columns = ["pk1", "pk2", "field_string"] class Model1CompactView(CompactCRUDMixin, ModelView): datamodel = SQLAInterface(Model1) class Model3CompactView(CompactCRUDMixin, ModelView): datamodel = SQLAInterface(Model3) class Model1ViewWithRedirects(ModelView): datamodel = SQLAInterface(Model1) def post_add_redirect(self): return redirect("/") def 
post_edit_redirect(self): return redirect("/") def post_delete_redirect(self): return redirect("/") class Model1Filtered1View(ModelView): datamodel = SQLAInterface(Model1) base_filters = [["field_string", FilterStartsWith, "test2"]] class Model1MasterView(MasterDetailView): datamodel = SQLAInterface(Model1) related_views = [Model2View] class Model1Filtered2View(ModelView): datamodel = SQLAInterface(Model1) base_filters = [["field_integer", FilterEqual, 0]] class Model2ChartView(ChartView): datamodel = SQLAInterface(Model2) chart_title = "Test Model1 Chart" group_by_columns = ["field_string"] class Model2GroupByChartView(GroupByChartView): datamodel = SQLAInterface(Model2) chart_title = "Test Model1 Chart" definitions = [{ "group": "field_string", "series": [( aggregate_sum, "field_integer", aggregate_avg, "field_integer", aggregate_count, "field_integer", )], }] class Model2DirectByChartView(DirectByChartView): datamodel = SQLAInterface(Model2) chart_title = "Test Model1 Chart" list_title = "" definitions = [{ "group": "field_string", "series": ["field_integer", "field_float"] }] class Model2TimeChartView(TimeChartView): datamodel = SQLAInterface(Model2) chart_title = "Test Model1 Chart" group_by_columns = ["field_date"] class Model2DirectChartView(DirectChartView): datamodel = SQLAInterface(Model2) chart_title = "Test Model1 Chart" direct_columns = {"stat1": ("group", "field_integer")} class Model1MasterChartView(MasterDetailView): datamodel = SQLAInterface(Model1) related_views = [Model2DirectByChartView] class Model1FormattedView(ModelView): datamodel = SQLAInterface(Model1) list_columns = ["field_string"] show_columns = ["field_string"] formatters_columns = {"field_string": lambda x: "FORMATTED_STRING"} class ModelWithEnumsView(ModelView): datamodel = SQLAInterface(ModelWithEnums) self.appbuilder.add_view(Model1View, "Model1", category="Model1") self.appbuilder.add_view(Model1ViewWithRedirects, "Model1ViewWithRedirects", category="Model1") 
self.appbuilder.add_view(Model1CompactView, "Model1Compact", category="Model1") self.appbuilder.add_view(Model1MasterView, "Model1Master", category="Model1") self.appbuilder.add_view(Model1MasterChartView, "Model1MasterChart", category="Model1") self.appbuilder.add_view(Model1Filtered1View, "Model1Filtered1", category="Model1") self.appbuilder.add_view(Model1Filtered2View, "Model1Filtered2", category="Model1") self.appbuilder.add_view(Model1FormattedView, "Model1FormattedView", category="Model1FormattedView") self.appbuilder.add_view(Model2View, "Model2") self.appbuilder.add_view(Model22View, "Model22") self.appbuilder.add_view(Model2View, "Model2 Add", href="/model2view/add") self.appbuilder.add_view(Model2ChartView, "Model2 Chart") self.appbuilder.add_view(Model2GroupByChartView, "Model2 Group By Chart") self.appbuilder.add_view(Model2DirectByChartView, "Model2 Direct By Chart") self.appbuilder.add_view(Model2TimeChartView, "Model2 Time Chart") self.appbuilder.add_view(Model2DirectChartView, "Model2 Direct Chart") self.appbuilder.add_view(Model3View, "Model3") self.appbuilder.add_view(Model3CompactView, "Model3Compact") self.appbuilder.add_view(ModelWithEnumsView, "ModelWithEnums") self.appbuilder.add_view(PSView, "Generic DS PS View", category="PSView") role_admin = self.appbuilder.sm.find_role("Admin") self.appbuilder.sm.add_user("admin", "admin", "user", "*****@*****.**", role_admin, "general") role_read_only = self.appbuilder.sm.find_role("ReadOnly") self.appbuilder.sm.add_user( USERNAME_READONLY, "readonly", "readonly", "*****@*****.**", role_read_only, PASSWORD_READONLY, ) def tearDown(self): self.appbuilder = None self.app = None self.db = None log.debug("TEAR DOWN") def test_fab_views(self): """ Test views creation and registration """ self.assertEqual(len(self.appbuilder.baseviews), 35) def test_back(self): """ Test Back functionality """ with self.app.test_client() as c: self.browser_login(c, USERNAME_ADMIN, PASSWORD_ADMIN) 
c.get("/model1view/list/?_flt_0_field_string=f") c.get("/model2view/list/") c.get("/back", follow_redirects=True) assert request.args["_flt_0_field_string"] == "f" assert "/model1view/list/" == request.path def test_model_creation(self): """ Test Model creation """ from sqlalchemy.engine.reflection import Inspector engine = self.db.session.get_bind(mapper=None, clause=None) inspector = Inspector.from_engine(engine) # Check if tables exist self.assertIn("model1", inspector.get_table_names()) self.assertIn("model2", inspector.get_table_names()) self.assertIn("model3", inspector.get_table_names()) self.assertIn("model_with_enums", inspector.get_table_names()) def test_index(self): """ Test initial access and index message """ client = self.app.test_client() # Check for Welcome Message rv = client.get("/") data = rv.data.decode("utf-8") self.assertIn(DEFAULT_INDEX_STRING, data) def test_sec_login(self): """ Test Security Login, Logout, invalid login, invalid access """ client = self.app.test_client() # Try to List and Redirect to Login rv = client.get("/model1view/list/") self.assertEqual(rv.status_code, 302) rv = client.get("/model2view/list/") self.assertEqual(rv.status_code, 302) # Login and list with admin self.browser_login(client, USERNAME_ADMIN, PASSWORD_ADMIN) rv = client.get("/model1view/list/") self.assertEqual(rv.status_code, 200) rv = client.get("/model2view/list/") self.assertEqual(rv.status_code, 200) # Logout and and try to list self.browser_logout(client) rv = client.get("/model1view/list/") self.assertEqual(rv.status_code, 302) rv = client.get("/model2view/list/") self.assertEqual(rv.status_code, 302) # Invalid Login rv = self.browser_login(client, USERNAME_ADMIN, "wrong_password") data = rv.data.decode("utf-8") self.assertIn(INVALID_LOGIN_STRING, data) def test_auth_builtin_roles(self): """ Test Security builtin roles readonly """ client = self.app.test_client() self.browser_login(client, USERNAME_READONLY, PASSWORD_READONLY) # Test authorized GET rv 
= client.get("/model1view/list/") self.assertEqual(rv.status_code, 200) # Test authorized SHOW rv = client.get("/model1view/show/1") self.assertEqual(rv.status_code, 200) # Test unauthorized EDIT rv = client.get("/model1view/edit/1") self.assertEqual(rv.status_code, 302) # Test unauthorized DELETE rv = client.get("/model1view/delete/1") self.assertEqual(rv.status_code, 302) def test_sec_reset_password(self): """ Test Security reset password """ client = self.app.test_client() # Try Reset My password rv = client.get("/users/action/resetmypassword/1", follow_redirects=True) # Werkzeug update to 0.15.X sends this action to wrong redirect # Old test was: # data = rv.data.decode("utf-8") # ok_(ACCESS_IS_DENIED in data) self.assertEqual(rv.status_code, 404) # Reset My password rv = self.browser_login(client, USERNAME_ADMIN, PASSWORD_ADMIN) rv = client.get("/users/action/resetmypassword/1", follow_redirects=True) data = rv.data.decode("utf-8") self.assertIn("Reset Password Form", data) rv = client.post( "/resetmypassword/form", data=dict(password="******", conf_password="******"), follow_redirects=True, ) self.assertEqual(rv.status_code, 200) self.browser_logout(client) self.browser_login(client, USERNAME_ADMIN, "password") rv = client.post( "/resetmypassword/form", data=dict(password=PASSWORD_ADMIN, conf_password=PASSWORD_ADMIN), follow_redirects=True, ) self.assertEqual(rv.status_code, 200) # Reset Password Admin rv = client.get("/users/action/resetpasswords/1", follow_redirects=True) data = rv.data.decode("utf-8") self.assertIn("Reset Password Form", data) rv = client.post( "/resetmypassword/form", data=dict(password=PASSWORD_ADMIN, conf_password=PASSWORD_ADMIN), follow_redirects=True, ) self.assertEqual(rv.status_code, 200) def test_generic_interface(self): """ Test Generic Interface for generic-alter datasource """ client = self.app.test_client() self.browser_login(client, USERNAME_ADMIN, PASSWORD_ADMIN) rv = client.get("/psview/list", follow_redirects=True) 
self.assertEqual(rv.status_code, 200) def test_model_crud_add(self): """ Test ModelView CRUD Add """ client = self.app.test_client() rv = self.browser_login(client, USERNAME_ADMIN, PASSWORD_ADMIN) field_string = f"test{MODEL1_DATA_SIZE+1}" rv = client.post( "/model1view/add", data=dict( field_string=field_string, field_integer=f"{MODEL1_DATA_SIZE}", field_float=f"{float(MODEL1_DATA_SIZE)}", field_date="2014-01-01", ), follow_redirects=True, ) self.assertEqual(rv.status_code, 200) model = (self.db.session.query(Model1).filter_by( field_string=field_string).one_or_none()) self.assertEqual(model.field_string, field_string) self.assertEqual(model.field_integer, MODEL1_DATA_SIZE) # Revert data changes self.appbuilder.get_session.delete(model) self.appbuilder.get_session.commit() def test_model_crud_edit(self): """ Test ModelView CRUD Edit """ client = self.app.test_client() rv = self.browser_login(client, USERNAME_ADMIN, PASSWORD_ADMIN) model = (self.appbuilder.get_session.query(Model1).filter_by( field_string="test0").one_or_none()) pk = model.id rv = client.post( f"/model1view/edit/{pk}", data=dict(field_string="test_edit", field_integer="200"), follow_redirects=True, ) self.assertEqual(rv.status_code, 200) model = self.db.session.query(Model1).filter_by(id=pk).one_or_none() self.assertEqual(model.field_string, "test_edit") self.assertEqual(model.field_integer, 200) # Revert data changes insert_model1(self.appbuilder.get_session, i=pk - 1) def test_model_crud_delete(self): """ Test Model CRUD delete """ client = self.app.test_client() self.browser_login(client, USERNAME_ADMIN, PASSWORD_ADMIN) model = (self.appbuilder.get_session.query(Model2).filter_by( field_string="test0").one_or_none()) pk = model.id rv = client.get(f"/model2view/delete/{pk}", follow_redirects=True) self.assertEqual(rv.status_code, 200) model = self.db.session.query(Model2).get(pk) self.assertEqual(model, None) # Revert data changes insert_model2(self.appbuilder.get_session, i=0) def 
test_model_delete_integrity(self): """ Test Model CRUD delete integrity validation """ client = self.app.test_client() self.browser_login(client, USERNAME_ADMIN, PASSWORD_ADMIN) model1 = (self.appbuilder.get_session.query(Model1).filter_by( field_string="test1").one_or_none()) pk = model1.id rv = client.get(f"/model1view/delete/{pk}", follow_redirects=True) self.assertEqual(rv.status_code, 200) model = self.db.session.query(Model1).filter_by(id=pk).one_or_none() self.assertNotEqual(model, None) def test_model_crud_composite_pk(self): """ MVC CRUD generic-alter datasource where model has composite primary keys """ try: from urllib import quote except Exception: from urllib.parse import quote client = self.app.test_client() rv = self.browser_login(client, USERNAME_ADMIN, PASSWORD_ADMIN) rv = client.post( "/model3view/add", data=dict(pk1="1", pk2=datetime.datetime(2017, 1, 1), field_string="foo2"), follow_redirects=True, ) self.assertEqual(rv.status_code, 200) model = (self.appbuilder.get_session.query(Model3).filter_by( pk1="1").one_or_none()) self.assertEqual(model.pk1, 1) self.assertEqual(model.pk2, datetime.datetime(2017, 1, 1)) self.assertEqual(model.field_string, "foo2") pk = '[1, {"_type": "datetime", "value": "2017-01-01T00:00:00.000000"}]' rv = client.get(f"/model3view/show/{quote(pk)}", follow_redirects=True) self.assertEqual(rv.status_code, 200) rv = client.post( "/model3view/edit/" + quote(pk), data=dict(pk1="2", pk2="2017-02-02 00:00:00", field_string="bar"), follow_redirects=True, ) self.assertEqual(rv.status_code, 200) model = (self.appbuilder.get_session.query(Model3).filter_by( pk1="2", pk2="2017-02-02 00:00:00").one_or_none()) self.assertEqual(model.pk1, 2) self.assertEqual(model.pk2, datetime.datetime(2017, 2, 2)) self.assertEqual(model.field_string, "bar") pk = '[2, {"_type": "datetime", "value": "2017-02-02T00:00:00.000000"}]' rv = client.get("/model3view/delete/" + quote(pk), follow_redirects=True) self.assertEqual(rv.status_code, 200) model = 
self.db.session.query(Model3).filter_by(pk1=2).one_or_none() self.assertEqual(model, None) def test_model_crud_add_with_enum(self): """ Test Model add for Model with Enum Columns """ client = self.app.test_client() self.browser_login(client, USERNAME_ADMIN, PASSWORD_ADMIN) data = {"enum1": "e3", "enum2": "e3"} rv = client.post("/modelwithenumsview/add", data=data, follow_redirects=True) self.assertEqual(rv.status_code, 200) model = (self.appbuilder.get_session.query(ModelWithEnums).filter_by( enum1="e3").one_or_none()) self.assertIsNotNone(model) self.assertEqual(model.enum2, TmpEnum.e3) # Revert data changes model = (self.appbuilder.get_session.query(ModelWithEnums).filter_by( enum1="e3").one_or_none()) self.appbuilder.get_session.delete(model) self.appbuilder.get_session.commit() def test_model_crud_edit_with_enum(self): """ Test Model edit for Model with Enum Columns """ client = self.app.test_client() self.browser_login(client, USERNAME_ADMIN, PASSWORD_ADMIN) data = {"enum1": "e3", "enum2": "e3"} pk = 1 rv = client.post(f"/modelwithenumsview/edit/{pk}", data=data, follow_redirects=True) self.assertEqual(rv.status_code, 200) model = (self.appbuilder.get_session.query(ModelWithEnums).filter_by( enum1="e3").one_or_none()) self.assertIsNotNone(model) self.assertEqual(model.enum2, TmpEnum.e3) # Revert data changes insert_model_with_enums(self.appbuilder.get_session, i=pk - 1) def test_formatted_cols(self): """ Test ModelView's formatters_columns """ client = self.app.test_client() rv = self.browser_login(client, USERNAME_ADMIN, PASSWORD_ADMIN) rv = client.get("/model1formattedview/list/") self.assertEqual(rv.status_code, 200) data = rv.data.decode("utf-8") self.assertIn("FORMATTED_STRING", data) rv = client.get("/model1formattedview/show/1") self.assertEqual(rv.status_code, 200) data = rv.data.decode("utf-8") self.assertIn("FORMATTED_STRING", data) def test_modelview_add_redirects(self): """ Test ModelView redirects after add """ client = self.app.test_client() 
self.browser_login(client, USERNAME_ADMIN, PASSWORD_ADMIN) rv = client.post("/model1viewwithredirects/add", data=dict(field_string="test_redirect")) self.assertEqual(rv.status_code, 302) self.assertEqual("http://localhost/", rv.headers["Location"]) # Revert data changes model1 = (self.appbuilder.get_session.query(Model1).filter_by( field_string="test_redirect").one_or_none()) self.appbuilder.get_session.delete(model1) self.appbuilder.get_session.commit() def test_modelview_edit_redirects(self): """ Test ModelView redirects after edit """ client = self.app.test_client() self.browser_login(client, USERNAME_ADMIN, PASSWORD_ADMIN) model_id = (self.db.session.query(Model1).filter_by( field_string="test0").one_or_none().id) rv = client.post( f"/model1viewwithredirects/edit/{model_id}", data=dict(field_string="test_redirect", field_integer="200"), ) self.assertEqual(rv.status_code, 302) self.assertEqual("http://localhost/", rv.headers["Location"]) # Revert data changes insert_model1(self.appbuilder.get_session, i=model_id - 1) def test_modelview_delete_redirects(self): """ Test ModelView redirects after delete """ client = self.app.test_client() rv = self.browser_login(client, USERNAME_ADMIN, PASSWORD_ADMIN) model_id = (self.db.session.query(Model1).filter_by( field_string="test0").first().id) rv = client.get(f"/model1viewwithredirects/delete/{model_id}") self.assertEqual(rv.status_code, 302) self.assertEqual("http://localhost/", rv.headers["Location"]) # Revert data changes insert_model1(self.appbuilder.get_session, i=model_id - 1) def test_add_excluded_cols(self): """ Test add_exclude_columns """ client = self.app.test_client() self.browser_login(client, USERNAME_ADMIN, PASSWORD_ADMIN) rv = client.get("/model22view/add") self.assertEqual(rv.status_code, 200) data = rv.data.decode("utf-8") self.assertIn("field_string", data) self.assertIn("field_integer", data) self.assertIn("field_float", data) self.assertIn("field_date", data) self.assertNotIn("excluded_string", data) 
def test_edit_excluded_cols(self):
    """
    Test edit_exclude_columns: the Model22View edit form must render
    the regular fields and must not render ``excluded_string``.
    """
    client = self.app.test_client()
    self.browser_login(client, USERNAME_ADMIN, PASSWORD_ADMIN)

    model = (
        self.appbuilder.get_session.query(Model2)
        .filter_by(field_string="test0")
        .one_or_none()
    )
    rv = client.get(f"/model22view/edit/{model.id}")
    self.assertEqual(rv.status_code, 200)
    data = rv.data.decode("utf-8")
    for field_name in ("field_string", "field_integer", "field_float", "field_date"):
        self.assertIn(field_name, data)
    self.assertNotIn("excluded_string", data)

def test_show_excluded_cols(self):
    """
    Test show_exclude_columns: the Model22View show page must display
    the regular field labels and must not display "Excluded String".
    """
    client = self.app.test_client()
    self.browser_login(client, USERNAME_ADMIN, PASSWORD_ADMIN)

    model = (
        self.appbuilder.get_session.query(Model2)
        .filter_by(field_string="test0")
        .one_or_none()
    )
    rv = client.get(f"/model22view/show/{model.id}")
    self.assertEqual(rv.status_code, 200)
    data = rv.data.decode("utf-8")
    for label in ("Field String", "Field Integer", "Field Float", "Field Date"):
        self.assertIn(label, data)
    self.assertNotIn("Excluded String", data)

def test_query_rel_fields(self):
    """
    Test add and edit form related fields filter
    """
    client = self.app.test_client()
    self.browser_login(client, USERNAME_ADMIN, PASSWORD_ADMIN)

    # Add form: "test0" must be offered, "test1" must not.
    rv = client.get("/model2view/add")
    data = rv.data.decode("utf-8")
    self.assertIn("test0", data)
    self.assertNotIn("test1", data)

    model2 = (
        self.appbuilder.get_session.query(Model2)
        .filter_by(field_string="test0")
        .one_or_none()
    )
    # Edit form: "test1" must be offered.
    rv = client.get(f"/model2view/edit/{model2.id}")
    data = rv.data.decode("utf-8")
    self.assertIn("test1", data)

def test_model_list_order(self):
    """
    Test Model order on lists: ascending and descending ordering by
    ``field_string`` must both render successfully.
    """
    client = self.app.test_client()
    self.browser_login(client, USERNAME_ADMIN, PASSWORD_ADMIN)

    rv = client.get(
        "/model1view/list?_oc_Model1View=field_string&_od_Model1View=asc",
        follow_redirects=True,
    )
    self.assertEqual(rv.status_code, 200)
    data = rv.data.decode("utf-8")
    self.assertIn("test0", data)

    rv = client.get(
        "/model1view/list?_oc_Model1View=field_string&_od_Model1View=desc",
        follow_redirects=True,
    )
    self.assertEqual(rv.status_code, 200)
    data = rv.data.decode("utf-8")
    self.assertIn(f"test{MODEL1_DATA_SIZE-1}", data)

def test_model_add_unique_validation(self):
    """
    Test Model add unique field validation: posting a duplicate
    ``field_string`` shows the validation message and adds no row.
    """
    client = self.app.test_client()
    self.browser_login(client, USERNAME_ADMIN, PASSWORD_ADMIN)

    # Test unique constraint
    rv = client.post(
        "/model1view/add",
        data=dict(field_string="test1", field_integer="2"),
        follow_redirects=True,
    )
    self.assertEqual(rv.status_code, 200)
    data = rv.data.decode("utf-8")
    self.assertIn(UNIQUE_VALIDATION_STRING, data)

    # The rejected POST must not have changed the row count.
    models = self.db.session.query(Model1).all()
    self.assertEqual(len(models), MODEL1_DATA_SIZE)

def test_model_add_required_validation(self):
    """
    Test Model add required fields validation: posting an empty
    ``field_string`` shows the validation message and adds no row.
    """
    client = self.app.test_client()
    self.browser_login(client, USERNAME_ADMIN, PASSWORD_ADMIN)

    # Test field required
    rv = client.post(
        "/model1view/add",
        data=dict(field_string="", field_integer="1"),
        follow_redirects=True,
    )
    self.assertEqual(rv.status_code, 200)
    data = rv.data.decode("utf-8")
    self.assertIn(NOTNULL_VALIDATION_STRING, data)

    # The rejected POST must not have changed the row count.
    models = self.db.session.query(Model1).all()
    self.assertEqual(len(models), MODEL1_DATA_SIZE)

def test_model_edit_unique_validation(self):
    """
    Test Model edit unique validation
    """
    client = self.app.test_client()
    self.browser_login(client, USERNAME_ADMIN, PASSWORD_ADMIN)

    rv = client.post(
        "/model1view/edit/1",
        data=dict(field_string="test2", field_integer="2"),
        follow_redirects=True,
    )
    self.assertEqual(rv.status_code, 200)
    data = rv.data.decode("utf-8")
    self.assertIn(UNIQUE_VALIDATION_STRING, data)

def test_model_edit_required_validation(self):
    """
    Test Model edit required validation
    """
    client = self.app.test_client()
    self.browser_login(client, USERNAME_ADMIN, PASSWORD_ADMIN)

    rv = client.post(
        "/model1view/edit/1",
        data=dict(field_string="", field_integer="2"),
        follow_redirects=True,
    )
    self.assertEqual(rv.status_code, 200)
    data = rv.data.decode("utf-8")
    self.assertIn(NOTNULL_VALIDATION_STRING, data)

def test_model_base_filter(self):
    """
    Test Model base filtered views
    """
    client = self.app.test_client()
    self.browser_login(client, USERNAME_ADMIN, PASSWORD_ADMIN)

    models = self.db.session.query(Model1).all()
    self.assertEqual(len(models), MODEL1_DATA_SIZE)

    # Base filter string starts with
    rv = client.get("/model1filtered1view/list/")
    data = rv.data.decode("utf-8")
    self.assertIn("test2", data)
    self.assertNotIn("test0", data)

    # Base filter integer equals
    rv = client.get("/model1filtered2view/list/")
    data = rv.data.decode("utf-8")
    self.assertIn("test0", data)
    self.assertNotIn("test1", data)

def test_model_list_method_field(self):
    """
    Tests a model's field has a method
    """
    client = self.app.test_client()
    self.browser_login(client, USERNAME_ADMIN, PASSWORD_ADMIN)

    rv = client.get("/model2view/list/")
    self.assertEqual(rv.status_code, 200)
    data = rv.data.decode("utf-8")
    self.assertIn("field_method_value", data)

def test_compactCRUDMixin(self):
    """
    Test CompactCRUD Mixin view with composite keys
    """
    # This file already relies on f-strings (Python 3.6+), so the old
    # Python 2 ``from urllib import quote`` fallback could never succeed.
    from urllib.parse import quote

    client = self.app.test_client()
    self.browser_login(client, USERNAME_ADMIN, PASSWORD_ADMIN)

    rv = client.get("/model1compactview/list/")
    self.assertEqual(rv.status_code, 200)

    # test with composite pk
    pk = '[3, {"_type": "datetime", "value": "2017-03-03T00:00:00"}]'
    rv = client.post(
        "/model3compactview/edit/" + quote(pk),
        data=dict(field_string="bar"),
        follow_redirects=True,
    )
    self.assertEqual(rv.status_code, 200)
    model = self.db.session.query(Model3).first()
    self.assertEqual(model.field_string, "bar")

    rv = client.get(
        "/model3compactview/delete/" + quote(pk), follow_redirects=True
    )
    self.assertEqual(rv.status_code, 200)
    model = self.db.session.query(Model3).first()
    self.assertIsNone(model)

    # Revert data changes
    insert_model3(self.appbuilder.get_session)

def test_edit_add_form_action_prefix_for_compactCRUDMixin(self):
    """
    Test form_action in add, form_action in edit (CompactCRUDMixin)
    """
    client = self.app.test_client()
    self.browser_login(client, USERNAME_ADMIN, PASSWORD_ADMIN)

    prefix = "/some-prefix"
    base_url = "http://localhost" + prefix
    session_form_action_key = "Model1CompactView__session_form_action"

    # Keep the client context open so flask's ``session`` proxy can be
    # inspected after each request.
    with client as c:
        expected_form_action = prefix + "/model1compactview/add/?"
        c.get("/model1compactview/add/", base_url=base_url)
        self.assertEqual(session[session_form_action_key], expected_form_action)

        expected_form_action = prefix + "/model1compactview/edit/1?"
        c.get("/model1compactview/edit/1", base_url=base_url)
        self.assertEqual(session[session_form_action_key], expected_form_action)

def test_charts_view(self):
    """
    Test Various Chart views
    """
    client = self.app.test_client()
    self.browser_login(client, USERNAME_ADMIN, PASSWORD_ADMIN)

    rv = client.get("/model2chartview/chart/")
    self.assertEqual(rv.status_code, 200)
    rv = client.get("/model2groupbychartview/chart/")
    self.assertEqual(rv.status_code, 200)
    rv = client.get("/model2directbychartview/chart/")
    self.assertEqual(rv.status_code, 200)
    # TODO: fix this
    rv = client.get("/model2timechartview/chart/")
    self.assertEqual(rv.status_code, 200)

def test_master_detail_view(self):
    """
    Test Master detail view
    """
    client = self.app.test_client()
    self.browser_login(client, USERNAME_ADMIN, PASSWORD_ADMIN)

    for url in (
        "/model1masterview/list/",
        "/model1masterview/list/1",
        "/model1masterchartview/list/",
        "/model1masterchartview/list/1",
    ):
        rv = client.get(url)
        self.assertEqual(rv.status_code, 200)

def test_api_read(self):
    """
    Testing the api/read endpoint
    """
    client = self.app.test_client()
    self.browser_login(client, USERNAME_ADMIN, PASSWORD_ADMIN)

    rv = client.get("/model1formattedview/api/read")
    self.assertEqual(rv.status_code, 200)
    data = json.loads(rv.data.decode("utf-8"))
    self.assertIn("result", data)
    self.assertIn("pks", data)
    # unittest assertion instead of a bare ``assert``: it is not stripped
    # under ``python -O`` and reports both operands on failure.
    self.assertGreater(len(data["result"]), 10)

def test_api_create(self):
    """
    Testing the api/create endpoint
    """
    client = self.app.test_client()
    self.browser_login(client, USERNAME_ADMIN, PASSWORD_ADMIN)

    rv = client.post(
        "/model1view/api/create",
        data=dict(field_string="zzz"),
        follow_redirects=True,
    )
    self.assertEqual(rv.status_code, 200)
    model1 = (
        self.db.session.query(Model1)
        .filter_by(field_string="zzz")
        .one_or_none()
    )
    self.assertIsNotNone(model1)

    # Revert data changes
    self.appbuilder.get_session.delete(model1)
    self.appbuilder.get_session.commit()

def test_api_update(self):
    """
    Validate that the api update endpoint
    updates [only] the fields in POST data
    """
    client = self.app.test_client()
    self.browser_login(client, USERNAME_ADMIN, PASSWORD_ADMIN)

    item = self.db.session.query(Model1).filter_by(id=1).one()
    field_integer_before = item.field_integer

    rv = client.put(
        "/model1view/api/update/1",
        data=dict(field_string="zzz"),
        follow_redirects=True,
    )
    self.assertEqual(rv.status_code, 200)

    item = self.db.session.query(Model1).filter_by(id=1).one()
    self.assertEqual(item.field_string, "zzz")
    # field_integer was not part of the PUT payload and must be untouched.
    self.assertEqual(item.field_integer, field_integer_before)

    # Revert data changes
    insert_model1(self.appbuilder.get_session, i=0)

def test_class_method_permission_override(self):
    """
    MVC: Test class method permission name override
    """
    from flask_appbuilder import ModelView
    from flask_appbuilder.models.sqla.interface import SQLAInterface

    class Model1PermOverride(ModelView):
        datamodel = SQLAInterface(Model1)
        class_permission_name = "view"
        method_permission_name = {
            "list": "access",
            "show": "access",
            "edit": "access",
            "add": "access",
            "delete": "access",
            "download": "access",
            "api_readvalues": "access",
            "api_column_edit": "access",
            "api_column_add": "access",
            "api_delete": "access",
            "api_update": "access",
            "api_create": "access",
            "api_get": "access",
            "api_read": "access",
            "api": "access",
        }

    self.model1permoverride = Model1PermOverride
    self.appbuilder.add_view_no_menu(Model1PermOverride)

    # Grant the single collapsed permission ("can_access" on "view")
    # to a dedicated role and user.
    role = self.appbuilder.sm.add_role("Test")
    pvm = self.appbuilder.sm.find_permission_view_menu("can_access", "view")
    self.appbuilder.sm.add_permission_role(role, pvm)
    self.appbuilder.sm.add_user(
        "test", "test", "user", "*****@*****.**", role, "test"
    )

    client = self.app.test_client()
    self.browser_login(client, "test", "test")

    rv = client.get("/model1permoverride/list/")
    self.assertEqual(rv.status_code, 200)

    rv = client.post(
        "/model1permoverride/add",
        data=dict(
            field_string="test1",
            field_integer="1",
            field_float="0.12",
            field_date="2014-01-01",
        ),
        follow_redirects=True,
    )
    self.assertEqual(rv.status_code, 200)

    model = (
        self.db.session.query(Model1)
        .filter_by(field_string="test1")
        .one_or_none()
    )
    self.assertEqual(model.field_string, "test1")
    self.assertEqual(model.field_integer, 1)

def test_method_permission_override(self):
    """
    MVC: Test method permission name override
    """
    from flask_appbuilder import ModelView
    from flask_appbuilder.models.sqla.interface import SQLAInterface

    class Model1PermOverride(ModelView):
        datamodel = SQLAInterface(Model1)
        method_permission_name = {
            "list": "read",
            "show": "read",
            "edit": "write",
            "add": "write",
            "delete": "write",
            "download": "read",
            "api_readvalues": "read",
            "api_column_edit": "write",
            "api_column_add": "write",
            "api_delete": "write",
            "api_update": "write",
            "api_create": "write",
            "api_get": "read",
            "api_read": "read",
            "api": "read",
        }

    self.model1permoverride = Model1PermOverride
    self.appbuilder.add_view_no_menu(Model1PermOverride)

    role = self.appbuilder.sm.add_role("Test")
    pvm_read = self.appbuilder.sm.find_permission_view_menu(
        "can_read", "Model1PermOverride"
    )
    pvm_write = self.appbuilder.sm.find_permission_view_menu(
        "can_write", "Model1PermOverride"
    )
    self.appbuilder.sm.add_permission_role(role, pvm_read)
    self.appbuilder.sm.add_permission_role(role, pvm_write)
    self.appbuilder.sm.add_user(
        "test", "test", "user", "*****@*****.**", role, "test"
    )

    client = self.app.test_client()
    self.browser_login(client, "test", "test")

    # Authorized add (role still holds can_write)
    new_field_string = f"test{MODEL1_DATA_SIZE+1}"
    rv = client.post(
        "/model1permoverride/add",
        data=dict(
            field_string=new_field_string,
            field_integer="1",
            field_float="0.12",
            field_date="2014-01-01",
        ),
        follow_redirects=True,
    )
    self.assertEqual(rv.status_code, 200)

    model1 = (
        self.appbuilder.get_session.query(Model1)
        .filter_by(field_string=new_field_string)
        .one_or_none()
    )
    self.assertIsNotNone(model1)

    # Revert data changes
    self.appbuilder.get_session.delete(model1)
    self.appbuilder.get_session.commit()

    # Verify write links are on the UI
    rv = client.get("/model1permoverride/list/")
    self.assertEqual(rv.status_code, 200)
    data = rv.data.decode("utf-8")
    self.assertIn("/model1permoverride/delete/1", data)
    self.assertIn("/model1permoverride/add", data)
    self.assertIn("/model1permoverride/edit/1", data)
    self.assertIn("/model1permoverride/show/1", data)

    # Delete write permission from Test Role
    role = self.appbuilder.sm.find_role("Test")
    pvm_write = self.appbuilder.sm.find_permission_view_menu(
        "can_write", "Model1PermOverride"
    )
    self.appbuilder.sm.del_permission_role(role, pvm_write)

    # Unauthorized delete
    model1 = (
        self.appbuilder.get_session.query(Model1)
        .filter_by(field_string="test1")
        .one_or_none()
    )
    pk = model1.id
    rv = client.get(f"/model1permoverride/delete/{pk}")
    self.assertEqual(rv.status_code, 302)
    model = self.db.session.query(Model1).filter_by(id=pk).one_or_none()
    self.assertEqual(model.field_string, "test1")

    # Verify write links are gone from UI
    rv = client.get("/model1permoverride/list/")
    self.assertEqual(rv.status_code, 200)
    data = rv.data.decode("utf-8")
    self.assertNotIn("/model1permoverride/delete/1", data)
    # NOTE(review): the presence check above looks for
    # "/model1permoverride/add" without a trailing slash; confirm that
    # the trailing-slash pattern here can actually match the add link.
    self.assertNotIn("/model1permoverride/add/", data)
    self.assertNotIn("/model1permoverride/edit/1", data)
    self.assertIn("/model1permoverride/show/1", data)

    # Revert data changes
    self.appbuilder.get_session.delete(self.appbuilder.sm.find_role("Test"))
    self.appbuilder.get_session.commit()

def test_action_permission_override(self):
    """
    MVC: Test action permission name override
    """
    from flask_appbuilder import action, ModelView
    from flask_appbuilder.models.sqla.interface import SQLAInterface

    class Model1PermOverride(ModelView):
        datamodel = SQLAInterface(Model1)
        method_permission_name = {
            "list": "read",
            "show": "read",
            "edit": "write",
            "add": "write",
            "delete": "write",
            "download": "read",
            "api_readvalues": "read",
            "api_column_edit": "write",
            "api_column_add": "write",
            "api_delete": "write",
            "api_update": "write",
            "api_create": "write",
            "api_get": "read",
            "api_read": "read",
            "api": "read",
            "action_one": "write",
        }

        @action("action1", "Action1", "", "fa-lock", multiple=True)
        def action_one(self, item):
            return "ACTION ONE"

    self.model1permoverride = Model1PermOverride
    self.appbuilder.add_view_no_menu(Model1PermOverride)

    # Create a dedicated role/user holding both read and write permissions
    role = self.appbuilder.sm.add_role("Test")
    self.appbuilder.sm.add_user(
        "test", "test", "user", "*****@*****.**", role, "test"
    )
    pvm_read = self.appbuilder.sm.find_permission_view_menu(
        "can_read", "Model1PermOverride"
    )
    pvm_write = self.appbuilder.sm.find_permission_view_menu(
        "can_write", "Model1PermOverride"
    )
    self.appbuilder.sm.add_permission_role(role, pvm_read)
    self.appbuilder.sm.add_permission_role(role, pvm_write)

    client = self.app.test_client()
    self.browser_login(client, "test", "test")

    # Authorized action (role still holds can_write)
    model1 = (
        self.appbuilder.get_session.query(Model1)
        .filter_by(field_string="test0")
        .one_or_none()
    )
    pk = model1.id
    rv = client.get(f"/model1permoverride/action/action1/{pk}")
    self.assertEqual(rv.status_code, 200)

    # Delete write permission from Test Role
    role = self.appbuilder.sm.find_role("Test")
    pvm_write = self.appbuilder.sm.find_permission_view_menu(
        "can_write", "Model1PermOverride"
    )
    self.appbuilder.sm.del_permission_role(role, pvm_write)

    # Unauthorized action
    rv = client.get("/model1permoverride/action/action1/1")
    self.assertEqual(rv.status_code, 302)

def test_permission_converge_compress(self):
    """
    MVC: Test permission name converge compress
    """
    from flask_appbuilder import ModelView
    from flask_appbuilder.models.sqla.interface import SQLAInterface

    class Model1PermConverge(ModelView):
        datamodel = SQLAInterface(Model1)
        class_permission_name = "view2"
        previous_class_permission_name = "Model1View"
        method_permission_name = {
            "list": "access",
            "show": "access",
            "edit": "access",
            "add": "access",
            "delete": "access",
            "download": "access",
            "api_readvalues": "access",
            "api_column_edit": "access",
            "api_column_add": "access",
            "api_delete": "access",
            "api_update": "access",
            "api_create": "access",
            "api_get": "access",
            "api_read": "access",
            "api": "access",
        }

    self.appbuilder.add_view_no_menu(Model1PermConverge)

    # Give the role two permissions on the previous view name.
    role = self.appbuilder.sm.add_role("Test")
    pvm = self.appbuilder.sm.find_permission_view_menu("can_list", "Model1View")
    self.appbuilder.sm.add_permission_role(role, pvm)
    pvm = self.appbuilder.sm.find_permission_view_menu("can_add", "Model1View")
    self.appbuilder.sm.add_permission_role(role, pvm)
    role = self.appbuilder.sm.find_role("Test")
    self.appbuilder.sm.add_user(
        "test", "test", "user", "*****@*****.**", role, "test"
    )

    # Remove previous class, Hack to test code change.
    # Locate it explicitly: a plain ``for ... break`` followed by
    # ``pop(i)`` would silently drop the LAST registered view whenever
    # Model1View is missing.
    baseviews = self.appbuilder.baseviews
    model1view_index = next(
        (
            index
            for index, baseview in enumerate(baseviews)
            if baseview.__class__.__name__ == "Model1View"
        ),
        None,
    )
    self.assertIsNotNone(model1view_index, "Model1View is not registered")
    baseviews.pop(model1view_index)

    target_state_transitions = {
        "add": {
            ("Model1View", "can_edit"): {("view2", "can_access")},
            ("Model1View", "can_add"): {("view2", "can_access")},
            ("Model1View", "can_list"): {("view2", "can_access")},
            ("Model1View", "can_download"): {("view2", "can_access")},
            ("Model1View", "can_show"): {("view2", "can_access")},
            ("Model1View", "can_delete"): {("view2", "can_access")},
        },
        "del_role_pvm": {
            ("Model1View", "can_show"),
            ("Model1View", "can_add"),
            ("Model1View", "can_download"),
            ("Model1View", "can_list"),
            ("Model1View", "can_edit"),
            ("Model1View", "can_delete"),
        },
        "del_views": {"Model1View"},
        "del_perms": set(),
    }
    state_transitions = self.appbuilder.security_converge()
    self.assertEqual(state_transitions, target_state_transitions)

    # The role is expected to end up with exactly one permission.
    role = self.appbuilder.sm.find_role("Test")
    self.assertEqual(len(role.permissions), 1)
def setUp(self):
    """
    Build the Flask application, database handle and AppBuilder used by
    the tests, declare the views under test, register them and create
    the ``admin`` user with the ``Admin`` role.
    """
    from flask import Flask
    from flask_appbuilder import AppBuilder
    from flask_appbuilder.models.sqla.interface import SQLAInterface
    from flask_appbuilder.views import ModelView

    self.app = Flask(__name__)
    self.basedir = os.path.abspath(os.path.dirname(__file__))
    self.app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///'
    self.app.config['CSRF_ENABLED'] = False
    self.app.config['SECRET_KEY'] = 'thisismyscretkey'
    self.app.config['WTF_CSRF_ENABLED'] = False

    self.db = SQLA(self.app)
    self.appbuilder = AppBuilder(self.app, self.db.session)

    sess = PSSession()

    # --- Generic (non-SQLAlchemy) datasource view ---------------------
    class PSView(ModelView):
        datamodel = GenericInterface(PSModel, sess)
        base_permissions = ['can_list', 'can_show']
        list_columns = ['UID', 'C', 'CMD', 'TIME']
        search_columns = ['UID', 'C', 'CMD']

    # --- Model2 views -------------------------------------------------
    class Model2View(ModelView):
        datamodel = SQLAInterface(Model2)
        list_columns = [
            'field_integer', 'field_float', 'field_string',
            'field_method', 'group.field_string'
        ]
        edit_form_query_rel_fields = {
            'group': [['field_string', FilterEqual, 'G2']]
        }
        add_form_query_rel_fields = {
            'group': [['field_string', FilterEqual, 'G1']]
        }

    class Model22View(ModelView):
        datamodel = SQLAInterface(Model2)
        list_columns = [
            'field_integer', 'field_float', 'field_string',
            'field_method', 'group.field_string'
        ]
        add_exclude_columns = ['excluded_string']
        edit_exclude_columns = ['excluded_string']
        show_exclude_columns = ['excluded_string']

    # --- Model1 views -------------------------------------------------
    class Model1View(ModelView):
        datamodel = SQLAInterface(Model1)
        related_views = [Model2View]
        list_columns = ['field_string', 'field_file']

    class Model1CompactView(CompactCRUDMixin, ModelView):
        datamodel = SQLAInterface(Model1)

    class Model1ViewWithRedirects(ModelView):
        datamodel = SQLAInterface(Model1)
        obj_id = 1

        def post_add_redirect(self):
            return redirect(
                'model1viewwithredirects/show/{0}'.format(REDIRECT_OBJ_ID))

        def post_edit_redirect(self):
            return redirect(
                'model1viewwithredirects/show/{0}'.format(REDIRECT_OBJ_ID))

        def post_delete_redirect(self):
            return redirect(
                'model1viewwithredirects/show/{0}'.format(REDIRECT_OBJ_ID))

    class Model1Filtered1View(ModelView):
        datamodel = SQLAInterface(Model1)
        base_filters = [['field_string', FilterStartsWith, 'a']]

    # Declared exactly once; an identical second declaration of this
    # class further down was redundant and has been dropped.
    class Model1MasterView(MasterDetailView):
        datamodel = SQLAInterface(Model1)
        related_views = [Model2View]

    class Model1Filtered2View(ModelView):
        datamodel = SQLAInterface(Model1)
        base_filters = [['field_integer', FilterEqual, 0]]

    # --- Chart views --------------------------------------------------
    class Model2ChartView(ChartView):
        datamodel = SQLAInterface(Model2)
        chart_title = 'Test Model1 Chart'
        group_by_columns = ['field_string']

    class Model2GroupByChartView(GroupByChartView):
        datamodel = SQLAInterface(Model2)
        chart_title = 'Test Model1 Chart'
        definitions = [{
            'group': 'field_string',
            'series': [(aggregate_sum, 'field_integer',
                        aggregate_avg, 'field_integer',
                        aggregate_count, 'field_integer')]
        }]

    class Model2DirectByChartView(DirectByChartView):
        datamodel = SQLAInterface(Model2)
        chart_title = 'Test Model1 Chart'
        definitions = [{
            'group': 'field_string',
            'series': ['field_integer', 'field_float']
        }]

    class Model2TimeChartView(TimeChartView):
        datamodel = SQLAInterface(Model2)
        chart_title = 'Test Model1 Chart'
        group_by_columns = ['field_date']

    class Model2DirectChartView(DirectChartView):
        datamodel = SQLAInterface(Model2)
        chart_title = 'Test Model1 Chart'
        direct_columns = {'stat1': ('group', 'field_integer')}

    class Model1MasterChartView(MasterDetailView):
        datamodel = SQLAInterface(Model1)
        related_views = [Model2DirectByChartView]

    class Model1FormattedView(ModelView):
        datamodel = SQLAInterface(Model1)
        list_columns = ['field_string']
        show_columns = ['field_string']
        formatters_columns = {
            'field_string': lambda x: 'FORMATTED_STRING',
        }

    # --- View registration (order preserved) --------------------------
    self.appbuilder.add_view(Model1View, "Model1", category='Model1')
    self.appbuilder.add_view(Model1ViewWithRedirects,
                             "Model1ViewWithRedirects",
                             category='Model1')
    self.appbuilder.add_view(Model1CompactView, "Model1Compact",
                             category='Model1')
    self.appbuilder.add_view(Model1MasterView, "Model1Master",
                             category='Model1')
    self.appbuilder.add_view(Model1MasterChartView, "Model1MasterChart",
                             category='Model1')
    self.appbuilder.add_view(Model1Filtered1View, "Model1Filtered1",
                             category='Model1')
    self.appbuilder.add_view(Model1Filtered2View, "Model1Filtered2",
                             category='Model1')
    self.appbuilder.add_view(Model1FormattedView, "Model1FormattedView",
                             category='Model1FormattedView')

    self.appbuilder.add_view(Model2View, "Model2")
    self.appbuilder.add_view(Model22View, "Model22")
    # Model2View is registered a second time under a different menu
    # name with an explicit href.
    self.appbuilder.add_view(Model2View, "Model2 Add",
                             href='/model2view/add')
    self.appbuilder.add_view(Model2ChartView, "Model2 Chart")
    self.appbuilder.add_view(Model2GroupByChartView,
                             "Model2 Group By Chart")
    self.appbuilder.add_view(Model2DirectByChartView,
                             "Model2 Direct By Chart")
    self.appbuilder.add_view(Model2TimeChartView, "Model2 Time Chart")
    self.appbuilder.add_view(Model2DirectChartView, "Model2 Direct Chart")

    self.appbuilder.add_view(PSView, "Generic DS PS View",
                             category='PSView')

    # --- Admin user used by the tests ---------------------------------
    role_admin = self.appbuilder.sm.find_role('Admin')
    self.appbuilder.sm.add_user('admin', 'admin', 'user',
                                '*****@*****.**', role_admin, 'general')
def setUp(self):
    """Build the Flask/AppBuilder fixture used by the tests in this case.

    Creates a Flask app configured from ``flask_appbuilder.tests.config_api``,
    defines the model, chart and generic-datasource views locally, registers
    them on a fresh ``AppBuilder`` and finally adds one admin user and one
    read-only user.
    """
    from flask import Flask
    from flask_appbuilder import AppBuilder
    from flask_appbuilder.models.sqla.interface import SQLAInterface
    from flask_appbuilder.views import ModelView

    flask_app = Flask(__name__)
    # Fail loudly on any undefined template variable.
    flask_app.jinja_env.undefined = jinja2.StrictUndefined
    flask_app.config.from_object("flask_appbuilder.tests.config_api")
    self.app = flask_app
    logging.basicConfig(level=logging.ERROR)

    self.db = SQLA(self.app)
    self.appbuilder = AppBuilder(self.app, self.db.session)

    ps_session = PSSession()

    # Columns listed by both Model2 views; each view receives its own copy.
    model2_columns = [
        "field_integer",
        "field_float",
        "field_string",
        "field_method",
        "group.field_string",
    ]

    class PSView(ModelView):
        datamodel = GenericInterface(PSModel, ps_session)
        base_permissions = ["can_list", "can_show"]
        list_columns = ["UID", "C", "CMD", "TIME"]
        search_columns = ["UID", "C", "CMD"]

    class Model2View(ModelView):
        datamodel = SQLAInterface(Model2)
        list_columns = list(model2_columns)
        # The add and edit forms restrict the related "group" choices
        # with different filter values.
        edit_form_query_rel_fields = {
            "group": [["field_string", FilterEqual, "test1"]],
        }
        add_form_query_rel_fields = {
            "group": [["field_string", FilterEqual, "test0"]],
        }

    class Model22View(ModelView):
        datamodel = SQLAInterface(Model2)
        list_columns = list(model2_columns)
        add_exclude_columns = ["excluded_string"]
        edit_exclude_columns = ["excluded_string"]
        show_exclude_columns = ["excluded_string"]

    class Model1View(ModelView):
        datamodel = SQLAInterface(Model1)
        related_views = [Model2View]
        list_columns = ["field_string", "field_file"]

    class Model3View(ModelView):
        datamodel = SQLAInterface(Model3)
        list_columns = ["pk1", "pk2", "field_string"]
        add_columns = ["pk1", "pk2", "field_string"]
        edit_columns = ["pk1", "pk2", "field_string"]

    class Model1CompactView(CompactCRUDMixin, ModelView):
        datamodel = SQLAInterface(Model1)

    class Model3CompactView(CompactCRUDMixin, ModelView):
        datamodel = SQLAInterface(Model3)

    class Model1ViewWithRedirects(ModelView):
        datamodel = SQLAInterface(Model1)

        # Every post-action hook sends the client back to the site root.
        def post_add_redirect(self):
            return redirect("/")

        def post_edit_redirect(self):
            return redirect("/")

        def post_delete_redirect(self):
            return redirect("/")

    class Model1Filtered1View(ModelView):
        datamodel = SQLAInterface(Model1)
        base_filters = [["field_string", FilterStartsWith, "test2"]]

    class Model1MasterView(MasterDetailView):
        datamodel = SQLAInterface(Model1)
        related_views = [Model2View]

    class Model1Filtered2View(ModelView):
        datamodel = SQLAInterface(Model1)
        base_filters = [["field_integer", FilterEqual, 0]]

    class Model2ChartView(ChartView):
        datamodel = SQLAInterface(Model2)
        chart_title = "Test Model1 Chart"
        group_by_columns = ["field_string"]

    class Model2GroupByChartView(GroupByChartView):
        datamodel = SQLAInterface(Model2)
        chart_title = "Test Model1 Chart"
        definitions = [
            {
                "group": "field_string",
                "series": [
                    (
                        aggregate_sum,
                        "field_integer",
                        aggregate_avg,
                        "field_integer",
                        aggregate_count,
                        "field_integer",
                    )
                ],
            }
        ]

    class Model2DirectByChartView(DirectByChartView):
        datamodel = SQLAInterface(Model2)
        chart_title = "Test Model1 Chart"
        list_title = ""
        definitions = [
            {
                "group": "field_string",
                "series": ["field_integer", "field_float"],
            }
        ]

    class Model2TimeChartView(TimeChartView):
        datamodel = SQLAInterface(Model2)
        chart_title = "Test Model1 Chart"
        group_by_columns = ["field_date"]

    class Model2DirectChartView(DirectChartView):
        datamodel = SQLAInterface(Model2)
        chart_title = "Test Model1 Chart"
        direct_columns = {"stat1": ("group", "field_integer")}

    class Model1MasterChartView(MasterDetailView):
        datamodel = SQLAInterface(Model1)
        related_views = [Model2DirectByChartView]

    class Model1FormattedView(ModelView):
        datamodel = SQLAInterface(Model1)
        list_columns = ["field_string"]
        show_columns = ["field_string"]
        formatters_columns = {"field_string": lambda _value: "FORMATTED_STRING"}

    class ModelWithEnumsView(ModelView):
        datamodel = SQLAInterface(ModelWithEnums)

    # (view class, menu name, extra add_view keyword arguments), kept in
    # the exact order the views are registered.
    view_registrations = [
        (Model1View, "Model1", {"category": "Model1"}),
        (Model1ViewWithRedirects, "Model1ViewWithRedirects", {"category": "Model1"}),
        (Model1CompactView, "Model1Compact", {"category": "Model1"}),
        (Model1MasterView, "Model1Master", {"category": "Model1"}),
        (Model1MasterChartView, "Model1MasterChart", {"category": "Model1"}),
        (Model1Filtered1View, "Model1Filtered1", {"category": "Model1"}),
        (Model1Filtered2View, "Model1Filtered2", {"category": "Model1"}),
        (
            Model1FormattedView,
            "Model1FormattedView",
            {"category": "Model1FormattedView"},
        ),
        (Model2View, "Model2", {}),
        (Model22View, "Model22", {}),
        (Model2View, "Model2 Add", {"href": "/model2view/add"}),
        (Model2ChartView, "Model2 Chart", {}),
        (Model2GroupByChartView, "Model2 Group By Chart", {}),
        (Model2DirectByChartView, "Model2 Direct By Chart", {}),
        (Model2TimeChartView, "Model2 Time Chart", {}),
        (Model2DirectChartView, "Model2 Direct Chart", {}),
        (Model3View, "Model3", {}),
        (Model3CompactView, "Model3Compact", {}),
        (ModelWithEnumsView, "ModelWithEnums", {}),
        (PSView, "Generic DS PS View", {"category": "PSView"}),
    ]
    for view_cls, menu_name, extra_kwargs in view_registrations:
        self.appbuilder.add_view(view_cls, menu_name, **extra_kwargs)

    security_manager = self.appbuilder.sm

    admin_role = security_manager.find_role("Admin")
    security_manager.add_user(
        "admin", "admin", "user", "*****@*****.**", admin_role, "general"
    )

    read_only_role = security_manager.find_role("ReadOnly")
    security_manager.add_user(
        USERNAME_READONLY,
        "readonly",
        "readonly",
        "*****@*****.**",
        read_only_role,
        PASSWORD_READONLY,
    )
return redirect("/superset/welcome") custom_sm = app.config.get( "CUSTOM_SECURITY_MANAGER") or SupersetSecurityManager if not issubclass(custom_sm, SupersetSecurityManager): raise Exception( """Your CUSTOM_SECURITY_MANAGER must now extend SupersetSecurityManager, not FAB's security manager. See [4565] in UPDATING.md""") with app.app_context(): appbuilder = AppBuilder( app, db.session, base_template="superset/base.html", indexview=MyIndexView, security_manager_class=custom_sm, update_perms=False, # Run `superset init` to update FAB's perms ) security_manager = appbuilder.sm results_backend = app.config.get("RESULTS_BACKEND") # Merge user defined feature flags with default feature flags _feature_flags = app.config.get("DEFAULT_FEATURE_FLAGS") or {} _feature_flags.update(app.config.get("FEATURE_FLAGS") or {}) def get_feature_flags(): GET_FEATURE_FLAGS_FUNC = app.config.get("GET_FEATURE_FLAGS_FUNC")
def index(self): return redirect('/superset/welcome') custom_sm = app.config.get( 'CUSTOM_SECURITY_MANAGER') or SupersetSecurityManager if not issubclass(custom_sm, SupersetSecurityManager): raise Exception( """Your CUSTOM_SECURITY_MANAGER must now extend SupersetSecurityManager, not FAB's security manager. See [4565] in UPDATING.md""") appbuilder = AppBuilder( app, db.session, base_template='superset/base.html', indexview=MyIndexView, security_manager_class=custom_sm, update_perms=get_update_perms_flag(), ) security_manager = appbuilder.sm results_backend = app.config.get('RESULTS_BACKEND') # Registering sources module_datasource_map = app.config.get('DEFAULT_MODULE_DS_MAP') module_datasource_map.update(app.config.get('ADDITIONAL_MODULE_DS_MAP')) ConnectorRegistry.register_sources(module_datasource_map) # Flask-Compress if conf.get('ENABLE_FLASK_COMPRESS'):
def setUp(self):
    """Build the Flask/AppBuilder fixture used by the tests in this case.

    Creates a Flask app backed by an in-memory SQLite database, declares a
    ``ReadOnly`` role through ``FAB_ROLES``, turns on SQLite foreign-key
    enforcement for every new connection, defines the model, chart and
    generic-datasource views locally, registers them on a fresh
    ``AppBuilder`` and finally adds one admin user and one read-only user.

    The SQLite ``connect`` listener is attached to the ``Engine`` class and
    is therefore global; it is unregistered again through ``addCleanup`` so
    listeners do not accumulate from one test to the next.
    """
    from flask import Flask
    from flask_appbuilder import AppBuilder
    from flask_appbuilder.models.sqla.interface import SQLAInterface
    from flask_appbuilder.views import ModelView
    from sqlalchemy.engine import Engine
    from sqlalchemy import event

    self.app = Flask(__name__)
    # Fail loudly on any undefined template variable.
    self.app.jinja_env.undefined = jinja2.StrictUndefined
    self.basedir = os.path.abspath(os.path.dirname(__file__))
    self.app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///"
    self.app.config["CSRF_ENABLED"] = False
    self.app.config["SECRET_KEY"] = "thisismyscretkey"
    self.app.config["WTF_CSRF_ENABLED"] = False
    self.app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
    self.app.config["FAB_ROLES"] = {
        "ReadOnly": [
            [".*", "can_list"],
            [".*", "can_show"],
        ]
    }
    logging.basicConfig(level=logging.ERROR)

    @event.listens_for(Engine, "connect")
    def set_sqlite_pragma(dbapi_connection, connection_record):
        # Force SQLite to enforce foreign key constraints.
        cursor = dbapi_connection.cursor()
        cursor.execute("PRAGMA foreign_keys=ON")
        cursor.close()

    # The listener above is registered on the Engine *class*, so without
    # this cleanup every setUp call would leave one more listener behind.
    self.addCleanup(event.remove, Engine, "connect", set_sqlite_pragma)

    self.db = SQLA(self.app)
    self.appbuilder = AppBuilder(self.app, self.db.session)

    sess = PSSession()

    class PSView(ModelView):
        datamodel = GenericInterface(PSModel, sess)
        base_permissions = ["can_list", "can_show"]
        list_columns = ["UID", "C", "CMD", "TIME"]
        search_columns = ["UID", "C", "CMD"]

    class Model2View(ModelView):
        datamodel = SQLAInterface(Model2)
        list_columns = [
            "field_integer",
            "field_float",
            "field_string",
            "field_method",
            "group.field_string",
        ]
        # The add and edit forms restrict the related "group" choices
        # with different filter values.
        edit_form_query_rel_fields = {
            "group": [["field_string", FilterEqual, "G2"]]
        }
        add_form_query_rel_fields = {"group": [["field_string", FilterEqual, "G1"]]}

    class Model22View(ModelView):
        datamodel = SQLAInterface(Model2)
        list_columns = [
            "field_integer",
            "field_float",
            "field_string",
            "field_method",
            "group.field_string",
        ]
        add_exclude_columns = ["excluded_string"]
        edit_exclude_columns = ["excluded_string"]
        show_exclude_columns = ["excluded_string"]

    class Model1View(ModelView):
        datamodel = SQLAInterface(Model1)
        related_views = [Model2View]
        list_columns = ["field_string", "field_file"]

    class Model3View(ModelView):
        datamodel = SQLAInterface(Model3)
        list_columns = ["pk1", "pk2", "field_string"]
        add_columns = ["pk1", "pk2", "field_string"]
        edit_columns = ["pk1", "pk2", "field_string"]

    class Model1CompactView(CompactCRUDMixin, ModelView):
        datamodel = SQLAInterface(Model1)

    class Model3CompactView(CompactCRUDMixin, ModelView):
        datamodel = SQLAInterface(Model3)

    class Model1ViewWithRedirects(ModelView):
        datamodel = SQLAInterface(Model1)
        obj_id = 1

        # Every post-action hook redirects to the show page of the
        # object identified by REDIRECT_OBJ_ID.
        def post_add_redirect(self):
            return redirect(
                "/model1viewwithredirects/show/{0}".format(REDIRECT_OBJ_ID)
            )

        def post_edit_redirect(self):
            return redirect(
                "/model1viewwithredirects/show/{0}".format(REDIRECT_OBJ_ID)
            )

        def post_delete_redirect(self):
            return redirect(
                "/model1viewwithredirects/show/{0}".format(REDIRECT_OBJ_ID)
            )

    class Model1Filtered1View(ModelView):
        datamodel = SQLAInterface(Model1)
        base_filters = [["field_string", FilterStartsWith, "a"]]

    class Model1MasterView(MasterDetailView):
        datamodel = SQLAInterface(Model1)
        related_views = [Model2View]

    class Model1Filtered2View(ModelView):
        datamodel = SQLAInterface(Model1)
        base_filters = [["field_integer", FilterEqual, 0]]

    class Model2ChartView(ChartView):
        datamodel = SQLAInterface(Model2)
        chart_title = "Test Model1 Chart"
        group_by_columns = ["field_string"]

    class Model2GroupByChartView(GroupByChartView):
        datamodel = SQLAInterface(Model2)
        chart_title = "Test Model1 Chart"
        definitions = [
            {
                "group": "field_string",
                "series": [
                    (
                        aggregate_sum,
                        "field_integer",
                        aggregate_avg,
                        "field_integer",
                        aggregate_count,
                        "field_integer",
                    )
                ],
            }
        ]

    class Model2DirectByChartView(DirectByChartView):
        datamodel = SQLAInterface(Model2)
        chart_title = "Test Model1 Chart"
        list_title = ""
        definitions = [
            {"group": "field_string", "series": ["field_integer", "field_float"]}
        ]

    class Model2TimeChartView(TimeChartView):
        datamodel = SQLAInterface(Model2)
        chart_title = "Test Model1 Chart"
        group_by_columns = ["field_date"]

    class Model2DirectChartView(DirectChartView):
        datamodel = SQLAInterface(Model2)
        chart_title = "Test Model1 Chart"
        direct_columns = {"stat1": ("group", "field_integer")}

    class Model1MasterChartView(MasterDetailView):
        datamodel = SQLAInterface(Model1)
        related_views = [Model2DirectByChartView]

    class Model1FormattedView(ModelView):
        datamodel = SQLAInterface(Model1)
        list_columns = ["field_string"]
        show_columns = ["field_string"]
        formatters_columns = {"field_string": lambda x: "FORMATTED_STRING"}

    class ModelWithEnumsView(ModelView):
        datamodel = SQLAInterface(ModelWithEnums)

    self.appbuilder.add_view(Model1View, "Model1", category="Model1")
    self.appbuilder.add_view(
        Model1ViewWithRedirects, "Model1ViewWithRedirects", category="Model1"
    )
    self.appbuilder.add_view(Model1CompactView, "Model1Compact", category="Model1")
    self.appbuilder.add_view(Model1MasterView, "Model1Master", category="Model1")
    self.appbuilder.add_view(
        Model1MasterChartView, "Model1MasterChart", category="Model1"
    )
    self.appbuilder.add_view(
        Model1Filtered1View, "Model1Filtered1", category="Model1"
    )
    self.appbuilder.add_view(
        Model1Filtered2View, "Model1Filtered2", category="Model1"
    )
    self.appbuilder.add_view(
        Model1FormattedView, "Model1FormattedView", category="Model1FormattedView"
    )

    self.appbuilder.add_view(Model2View, "Model2")
    self.appbuilder.add_view(Model22View, "Model22")
    self.appbuilder.add_view(Model2View, "Model2 Add", href="/model2view/add")
    self.appbuilder.add_view(Model2ChartView, "Model2 Chart")
    self.appbuilder.add_view(Model2GroupByChartView, "Model2 Group By Chart")
    self.appbuilder.add_view(Model2DirectByChartView, "Model2 Direct By Chart")
    self.appbuilder.add_view(Model2TimeChartView, "Model2 Time Chart")
    self.appbuilder.add_view(Model2DirectChartView, "Model2 Direct Chart")

    self.appbuilder.add_view(Model3View, "Model3")
    self.appbuilder.add_view(Model3CompactView, "Model3Compact")

    self.appbuilder.add_view(ModelWithEnumsView, "ModelWithEnums")

    self.appbuilder.add_view(PSView, "Generic DS PS View", category="PSView")

    role_admin = self.appbuilder.sm.find_role("Admin")
    self.appbuilder.sm.add_user(
        "admin", "admin", "user", "*****@*****.**", role_admin, "general"
    )
    role_read_only = self.appbuilder.sm.find_role("ReadOnly")
    self.appbuilder.sm.add_user(
        USERNAME_READONLY,
        "readonly",
        "readonly",
        "*****@*****.**",
        role_read_only,
        PASSWORD_READONLY,
    )
class FlaskTestCase(unittest.TestCase): def setUp(self): from flask import Flask from flask_appbuilder import AppBuilder from flask_appbuilder.models.mongoengine.interface import MongoEngineInterface from flask_appbuilder import ModelView from flask_appbuilder.security.mongoengine.manager import SecurityManager self.app = Flask(__name__) self.basedir = os.path.abspath(os.path.dirname(__file__)) self.app.config['MONGODB_SETTINGS'] = {'DB': 'test'} self.app.config['CSRF_ENABLED'] = False self.app.config['SECRET_KEY'] = 'thisismyscretkey' self.app.config['WTF_CSRF_ENABLED'] = False self.db = MongoEngine(self.app) self.appbuilder = AppBuilder(self.app, security_manager_class=SecurityManager) class Model2View(ModelView): datamodel = MongoEngineInterface(Model2) list_columns = ['field_integer', 'field_float', 'field_string', 'field_method', 'group.field_string'] edit_form_query_rel_fields = {'group':[['field_string', FilterEqual, 'G2']]} add_form_query_rel_fields = {'group':[['field_string', FilterEqual, 'G1']]} class Model1View(ModelView): datamodel = MongoEngineInterface(Model1) related_views = [Model2View] list_columns = ['field_string','field_file'] class Model1CompactView(CompactCRUDMixin, ModelView): datamodel = MongoEngineInterface(Model1) class Model1Filtered1View(ModelView): datamodel = MongoEngineInterface(Model1) base_filters = [['field_string', FilterStartsWith, 'a']] class Model1MasterView(MasterDetailView): datamodel = MongoEngineInterface(Model1) related_views = [Model2View] class Model1Filtered2View(ModelView): datamodel = MongoEngineInterface(Model1) base_filters = [['field_integer', FilterEqual, 0]] class Model2GroupByChartView(GroupByChartView): datamodel = MongoEngineInterface(Model2) chart_title = 'Test Model1 Chart' definitions = [ { 'group':'field_string', 'series':[(aggregate_sum,'field_integer', aggregate_avg, 'field_integer', aggregate_count,'field_integer') ] } ] class Model2DirectByChartView(DirectByChartView): datamodel = 
MongoEngineInterface(Model2) chart_title = 'Test Model1 Chart' definitions = [ { 'group':'field_string', 'series':['field_integer','field_float'] } ] class Model2DirectChartView(DirectChartView): datamodel = MongoEngineInterface(Model2) chart_title = 'Test Model1 Chart' direct_columns = {'stat1': ('group', 'field_integer')} class Model1MasterView(MasterDetailView): datamodel = MongoEngineInterface(Model1) related_views = [Model2View] class Model1MasterChartView(MasterDetailView): datamodel = MongoEngineInterface(Model1) related_views = [Model2DirectByChartView] self.appbuilder.add_view(Model1View, "Model1", category='Model1') self.appbuilder.add_view(Model1CompactView, "Model1Compact", category='Model1') self.appbuilder.add_view(Model1MasterView, "Model1Master", category='Model1') self.appbuilder.add_view(Model1MasterChartView, "Model1MasterChart", category='Model1') self.appbuilder.add_view(Model1Filtered1View, "Model1Filtered1", category='Model1') self.appbuilder.add_view(Model1Filtered2View, "Model1Filtered2", category='Model1') self.appbuilder.add_view(Model2View, "Model2") self.appbuilder.add_view(Model2View, "Model2 Add", href='/model2view/add') self.appbuilder.add_view(Model2GroupByChartView, "Model2 Group By Chart") self.appbuilder.add_view(Model2DirectByChartView, "Model2 Direct By Chart") self.appbuilder.add_view(Model2DirectChartView, "Model2 Direct Chart") role_admin = self.appbuilder.sm.find_role('Admin') try: self.appbuilder.sm.add_user('admin', 'admin', 'user', '*****@*****.**', role_admin, 'general') except: pass def tearDown(self): self.appbuilder = None self.app = None self.db = None log.debug("TEAR DOWN") """ --------------------------------- TEST HELPER FUNCTIONS --------------------------------- """ def login(self, client, username, password): # Login with default admin return client.post('/login/', data=dict( username=username, password=password ), follow_redirects=True) def logout(self, client): return client.get('/logout/') def 
insert_data(self): for x, i in zip(string.ascii_letters[:23], range(23)): model = Model1(field_string="%stest" % (x), field_integer=i) model.save() def insert_data2(self): models1 = [Model1(field_string='G1'), Model1(field_string='G2'), Model1(field_string='G3')] for model1 in models1: try: model1.save() for x, i in zip(string.ascii_letters[:10], range(10)): model = Model2(field_string="%stest" % (x), field_integer=random.randint(1, 10), field_float=random.uniform(0.0, 1.0), group=model1) year = random.choice(range(1900, 2012)) month = random.choice(range(1, 12)) day = random.choice(range(1, 28)) model.field_date = datetime.datetime(year, month, day) model.save() except Exception as e: print("ERROR {0}".format(str(e))) def clean_data(self): Model1.drop_collection() Model2.drop_collection() def test_fab_views(self): """ Test views creation and registration """ eq_(len(self.appbuilder.baseviews), 22) # current minimal views are 11 def test_index(self): """ Test initial access and index message """ client = self.app.test_client() # Check for Welcome Message rv = client.get('/') data = rv.data.decode('utf-8') ok_(DEFAULT_INDEX_STRING in data) def test_sec_login(self): """ Test Security Login, Logout, invalid login, invalid access """ client = self.app.test_client() # Try to List and Redirect to Login rv = client.get('/model1view/list/') eq_(rv.status_code, 302) rv = client.get('/model2view/list/') eq_(rv.status_code, 302) # Login and list with admin self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) rv = client.get('/model1view/list/') eq_(rv.status_code, 200) rv = client.get('/model2view/list/') eq_(rv.status_code, 200) # Logout and and try to list self.logout(client) rv = client.get('/model1view/list/') eq_(rv.status_code, 302) rv = client.get('/model2view/list/') eq_(rv.status_code, 302) # Invalid Login rv = self.login(client, DEFAULT_ADMIN_USER, 'password') data = rv.data.decode('utf-8') ok_(INVALID_LOGIN_STRING in data) def 
test_sec_reset_password(self): """ Test Security reset password """ from flask_appbuilder.security.mongoengine.models import User client = self.app.test_client() # Try Reset My password user = User.objects.filter(**{'username': '******'})[0] rv = client.get('/users/action/resetmypassword/{0}'.format(user.id), follow_redirects=True) data = rv.data.decode('utf-8') ok_(ACCESS_IS_DENIED in data) #Reset My password rv = self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) rv = client.get('/users/action/resetmypassword/{0}'.format(user.id), follow_redirects=True) data = rv.data.decode('utf-8') ok_("Reset Password Form" in data) rv = client.post('/resetmypassword/form', data=dict(password='******', conf_password='******'), follow_redirects=True) eq_(rv.status_code, 200) self.logout(client) self.login(client, DEFAULT_ADMIN_USER, 'password') rv = client.post('/resetmypassword/form', data=dict(password=DEFAULT_ADMIN_PASSWORD, conf_password=DEFAULT_ADMIN_PASSWORD), follow_redirects=True) eq_(rv.status_code, 200) #Reset Password Admin rv = client.get('/users/action/resetpasswords/{0}'.format(user.id), follow_redirects=True) data = rv.data.decode('utf-8') ok_("Reset Password Form" in data) rv = client.post('/resetmypassword/form', data=dict(password=DEFAULT_ADMIN_PASSWORD, conf_password=DEFAULT_ADMIN_PASSWORD), follow_redirects=True) eq_(rv.status_code, 200) def test_generic_interface(self): """ Test Generic Interface for generic-alter datasource """ client = self.app.test_client() rv = self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) rv = client.get('/psview/list') data = rv.data.decode('utf-8') def test_model_crud(self): """ Test Model add, delete, edit """ client = self.app.test_client() rv = self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) rv = client.post('/model1view/add', data=dict(field_string='test1', field_integer='1', field_float='0.12', field_date='2014-01-01 23:10:07'), follow_redirects=True) eq_(rv.status_code, 200) model = 
Model1.objects[0] eq_(model.field_string, u'test1') eq_(model.field_integer, 1) model1 = Model1.objects(field_string='test1')[0] rv = client.post('/model1view/edit/{0}'.format(model1.id), data=dict(field_string='test2', field_integer='2'), follow_redirects=True) eq_(rv.status_code, 200) model = Model1.objects[0] eq_(model.field_string, u'test2') eq_(model.field_integer, 2) rv = client.get('/model1view/delete/{0}'.format(model.id), follow_redirects=True) eq_(rv.status_code, 200) model = Model1.objects eq_(len(model), 0) self.clean_data() def test_query_rel_fields(self): """ Test add and edit form related fields filter """ client = self.app.test_client() rv = self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) self.insert_data2() # Base filter string starts with rv = client.get('/model2view/add') data = rv.data.decode('utf-8') ok_('G1' in data) ok_('G2' not in data) model2 = Model2.objects[0] # Base filter string starts with rv = client.get('/model2view/edit/{0}'.format(model2.id)) data = rv.data.decode('utf-8') ok_('G2' in data) ok_('G1' not in data) self.clean_data() def test_model_list_order(self): """ Test Model order on lists """ self.insert_data() client = self.app.test_client() self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) rv = client.post('/model1view/list?_oc_Model1View=field_string&_od_Model1View=asc', follow_redirects=True) eq_(rv.status_code, 200) data = rv.data.decode('utf-8') # TODO # VALIDATE LIST IS ORDERED rv = client.post('/model1view/list?_oc_Model1View=field_string&_od_Model1View=desc', follow_redirects=True) eq_(rv.status_code, 200) data = rv.data.decode('utf-8') # TODO # VALIDATE LIST IS ORDERED self.clean_data() def test_model_add_validation(self): """ Test Model add validations """ client = self.app.test_client() self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) rv = client.post('/model1view/add', data=dict(field_string='test1', field_integer='1'), follow_redirects=True) eq_(rv.status_code, 200) rv = 
client.post('/model1view/add', data=dict(field_string='test1', field_integer='2'), follow_redirects=True) eq_(rv.status_code, 200) data = rv.data.decode('utf-8') ok_(UNIQUE_VALIDATION_STRING in data) model = Model1.objects() eq_(len(model), 1) rv = client.post('/model1view/add', data=dict(field_string='', field_integer='1'), follow_redirects=True) eq_(rv.status_code, 200) data = rv.data.decode('utf-8') ok_(NOTNULL_VALIDATION_STRING in data) model = Model1.objects() eq_(len(model), 1) self.clean_data() def test_model_edit_validation(self): """ Test Model edit validations """ client = self.app.test_client() self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) client.post('/model1view/add', data=dict(field_string='test1', field_integer='1'), follow_redirects=True) model1 = Model1.objects(field_string='test1')[0] client.post('/model1view/add', data=dict(field_string='test2', field_integer='1'), follow_redirects=True) rv = client.post('/model1view/edit/{0}'.format(model1.id), data=dict(field_string='test2', field_integer='2'), follow_redirects=True) eq_(rv.status_code, 200) data = rv.data.decode('utf-8') ok_(UNIQUE_VALIDATION_STRING in data) rv = client.post('/model1view/edit/{0}'.format(model1.id), data=dict(field_string='', field_integer='2'), follow_redirects=True) eq_(rv.status_code, 200) data = rv.data.decode('utf-8') ok_(NOTNULL_VALIDATION_STRING in data) self.clean_data() def test_model_base_filter(self): """ Test Model base filtered views """ client = self.app.test_client() self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) self.insert_data() models = Model1.objects() eq_(len(models), 23) # Base filter string starts with rv = client.get('/model1filtered1view/list/') data = rv.data.decode('utf-8') ok_('atest' in data) ok_('btest' not in data) # Base filter integer equals rv = client.get('/model1filtered2view/list/') data = rv.data.decode('utf-8') ok_('atest' in data) ok_('btest' not in data) self.clean_data() def 
test_model_list_method_field(self): """ Tests a model's field has a method """ client = self.app.test_client() self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) self.insert_data2() rv = client.get('/model2view/list/') eq_(rv.status_code, 200) data = rv.data.decode('utf-8') ok_('field_method_value' in data) self.clean_data() def test_compactCRUDMixin(self): """ Test CompactCRUD Mixin view """ client = self.app.test_client() self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) self.insert_data2() rv = client.get('/model1compactview/list/') eq_(rv.status_code, 200) self.clean_data() def test_charts_view(self): """ Test Various Chart views """ client = self.app.test_client() self.login(client, DEFAULT_ADMIN_USER, DEFAULT_ADMIN_PASSWORD) self.insert_data2() log.info("CHART TEST") rv = client.get('/model2groupbychartview/chart/') eq_(rv.status_code, 200) rv = client.get('/model2directbychartview/chart/') eq_(rv.status_code, 200) rv = client.get('/model2directchartview/chart/') #eq_(rv.status_code, 200) self.clean_data() """