Beispiel #1
0
    def setUp(self):
        """Reset DAG tables and register two sample views on ``self.appbuilder``.

        ``self.app`` and ``self.appbuilder`` are not created here; they must
        already exist (presumably built in class-level setup — confirm).
        """
        # Start every test from empty DAG / DAG-run tables.
        clear_db_runs()
        clear_db_dags()
        self.db = SQLA(self.app)

        # One plain view and one model view, each under its own menu category.
        sample_views = (
            (SomeBaseView, "SomeBaseView", "BaseViews"),
            (SomeModelView, "SomeModelView", "ModelViews"),
        )
        for view_class, view_name, menu_category in sample_views:
            self.appbuilder.add_view(view_class, view_name, category=menu_category)

        log.debug("Complete setup!")
Beispiel #2
0
    def setUp(self):
        """Create a Flask app configured for OAuth authentication and bind its database."""
        app = self.app = Flask(__name__)
        # Raise on undefined template variables instead of rendering them empty.
        app.jinja_env.undefined = jinja2.StrictUndefined
        app.config.update(
            {
                "SQLALCHEMY_DATABASE_URI": os.environ.get("SQLALCHEMY_DATABASE_URI"),
                "SQLALCHEMY_TRACK_MODIFICATIONS": False,
                "AUTH_TYPE": AUTH_OAUTH,
                # The provider list may stay empty: these tests never talk to
                # an external OAuth provider.
                "OAUTH_PROVIDERS": [],
            }
        )

        # Bind the database to the fully configured app.
        self.db = SQLA(app)
Beispiel #3
0
    def setUp(self):
        """Build an API-configured app with a static and a conditionally listed Model1 view."""
        from flask import Flask
        from flask_appbuilder import AppBuilder
        from flask_appbuilder.views import ModelView

        app = self.app = Flask(__name__)
        self.basedir = os.path.abspath(os.path.dirname(__file__))
        app.config.from_object("flask_appbuilder.tests.config_api")
        app.config["FAB_API_MAX_PAGE_SIZE"] = MAX_PAGE_SIZE

        self.db = SQLA(app)
        builder = self.appbuilder = AppBuilder(app, self.db.session)

        class Model1View(ModelView):
            datamodel = SQLAInterface(Model1)

        # Flag consulted by the ``menu_cond`` lambda below. The lambda reads it
        # at call time, so tests can flip it after setup.
        self._conditional_value = True

        class Model1ViewDynamic(ModelView):
            datamodel = SQLAInterface(Model1)

        builder.add_view(Model1View, "Model1")
        builder.add_view(
            Model1ViewDynamic,
            "Model1Dynamic",
            label="Model1 Dynamic",
            menu_cond=lambda: self._conditional_value,
        )
Beispiel #4
0
    def setUp(self):
        """Create an app from the API test config with SQLA and AppBuilder attached."""
        from flask import Flask
        from flask_appbuilder import AppBuilder

        app = Flask(__name__)
        app.config.from_object("flask_appbuilder.tests.config_api")
        self.app = app

        database = SQLA(app)
        self.db = database
        self.appbuilder = AppBuilder(app, database.session)
Beispiel #5
0
    def setUp(self):
        """Create a strict-templating app from the API test config with AppBuilder attached."""
        app = self.app = Flask(__name__)
        # Undefined template variables raise instead of rendering empty.
        app.jinja_env.undefined = jinja2.StrictUndefined
        app.config.from_object("flask_appbuilder.tests.config_api")
        # Keep test output quiet: only ERROR and above (if the root logger is
        # not configured yet).
        logging.basicConfig(level=logging.ERROR)

        database = self.db = SQLA(app)
        self.appbuilder = AppBuilder(app, database.session)
Beispiel #6
0
 def setUp(self):
     """Create an app using AirflowSecurityManager, add two sample views and an admin user."""
     self.app = Flask(__name__)
     self.app.config.update({
         'SQLALCHEMY_DATABASE_URI': 'sqlite:///',
         'SECRET_KEY': 'secret_key',
         'CSRF_ENABLED': False,
         'WTF_CSRF_ENABLED': False,
     })
     self.db = SQLA(self.app)
     builder = AppBuilder(self.app, self.db.session,
                          security_manager_class=AirflowSecurityManager)
     self.appbuilder = builder
     self.security_manager = builder.sm

     # One plain view and one model view, each under its own menu category.
     for view_class, view_name, menu_category in (
             (SomeBaseView, "SomeBaseView", "BaseViews"),
             (SomeModelView, "SomeModelView", "ModelViews")):
         builder.add_view(view_class, view_name, category=menu_category)

     admin_role = self.security_manager.find_role('Admin')
     self.user = builder.sm.add_user('admin', 'admin', 'user', '*****@*****.**',
                                     admin_role, 'general')
     log.debug("Complete setup!")
    def setUp(self):
        """Start the mocked LDAP server and create an LDAP-authenticated app with a database."""
        # Mocked directory served at ldap://localhost/.
        self.mockldap.start()
        self.ldapobj = self.mockldap["ldap://localhost/"]

        app = self.app = Flask(__name__)
        # Undefined template variables raise instead of rendering empty.
        app.jinja_env.undefined = jinja2.StrictUndefined
        app.config.update(
            {
                "SQLALCHEMY_DATABASE_URI": os.environ.get("SQLALCHEMY_DATABASE_URI"),
                "SQLALCHEMY_TRACK_MODIFICATIONS": False,
                "AUTH_TYPE": AUTH_LDAP,
                "AUTH_LDAP_SERVER": "ldap://localhost/",
                "AUTH_LDAP_UID_FIELD": "uid",
                "AUTH_LDAP_FIRSTNAME_FIELD": "givenName",
                "AUTH_LDAP_LASTNAME_FIELD": "sn",
                "AUTH_LDAP_EMAIL_FIELD": "email",
            }
        )

        # Bind the database once configuration is complete.
        self.db = SQLA(app)
Beispiel #8
0
    def setUp(self):
        """Create an app from the OAuth test config with CSRF protection switched on."""
        from flask import Flask
        from flask_wtf import CSRFProtect
        from flask_appbuilder import AppBuilder

        app = self.app = Flask(__name__)
        app.config.from_object("flask_appbuilder.tests.config_oauth")
        # Override the loaded config so CSRF checks are active in these tests.
        app.config["WTF_CSRF_ENABLED"] = True

        self.csrf = CSRFProtect(app)
        database = self.db = SQLA(app)
        self.appbuilder = AppBuilder(app, database.session)
Beispiel #9
0
    def setUp(self):
        """Create an https-preferring app from tests.test_config using CustomSecurityManager."""
        from flask import Flask  # pylint: disable=C0415,E0401
        from flask_appbuilder import AppBuilder  # pylint: disable=C0415,E0401

        app = self.app = Flask(__name__)
        app.config.from_object('tests.test_config')
        # Scheme Flask uses when building external URLs outside a request.
        app.config['PREFERRED_URL_SCHEME'] = 'https'

        database = self.db = SQLA(app)  # pylint: disable=invalid-name
        self.appbuilder = AppBuilder(
            app, database.session, security_manager_class=CustomSecurityManager
        )
Beispiel #10
0
def create_app(config_name=None, indexview=None, security_manager_class=None):
    """Build the application and publish it via the module globals ``app``, ``db``, ``appbuilder``.

    :param config_name: optional extra config file loaded with ``from_pyfile``
        on top of the ``config`` object.
    :param indexview: forwarded to ``AppBuilder``.
    :param security_manager_class: forwarded to ``AppBuilder``.
    :return: the Flask application (also stored in the global ``app``).
    """
    global app, db, appbuilder

    app = Flask(__name__)
    # Wrap the WSGI app so X-Forwarded-* headers from a fronting proxy are honoured.
    app.wsgi_app = ProxyFix(app.wsgi_app)
    app.config.from_object('config')
    if config_name:
        app.config.from_pyfile(config_name)

    db = SQLA(app)
    builder_options = {
        'indexview': indexview,
        'security_manager_class': security_manager_class,
    }
    appbuilder = AppBuilder(app, db.session, **builder_options)

    return app
Beispiel #11
0
    def setUp(self):
        """Create a CSRF-protected SQLite-backed app and seed an admin user."""
        from flask import Flask
        from flask_wtf import CSRFProtect
        from flask_appbuilder import AppBuilder

        app = self.app = Flask(__name__)
        app.config.update(
            SQLALCHEMY_DATABASE_URI="sqlite:///",
            SECRET_KEY="thisismyscretkey",
            SQLALCHEMY_TRACK_MODIFICATIONS=False,
            WTF_CSRF_ENABLED=True,
        )

        self.csrf = CSRFProtect(app)
        database = self.db = SQLA(app)
        builder = self.appbuilder = AppBuilder(app, database.session)

        # Admin account, presumably used by the tests to log in.
        self.create_admin_user(builder, USERNAME, PASSWORD)
Beispiel #12
0
    def setUp(self):
        """Create a testing app, attach Flask-Migrate and apply migrations via ``upgrade()``."""
        flask_app = self.app = Flask(__name__)
        flask_app.config.update(
            SQLALCHEMY_DATABASE_URI='sqlite://',
            CSRF_ENABLED=False,
            SECRET_KEY='thisismyscretkey',
            WTF_CSRF_ENABLED=False,
            TESTING=True,
            DEBUG=True,
        )

        self.db = SQLA(flask_app)
        # NOTE: ``app`` here is a module-level name, not ``self.app``; it is
        # given a reference to this test's database.
        app.db = self.db
        Migrate(flask_app, self.db)
        self.appbuilder = AppBuilder(flask_app, self.db.session)

        # Migrations need an application context to resolve the database.
        with flask_app.app_context():
            upgrade()
 def setUp(self):
     """Create an app using AirflowSecurityManager, add two sample views and an admin user."""
     self.app = Flask(__name__)
     self.app.config.update({
         'SQLALCHEMY_DATABASE_URI': 'sqlite:///',
         'SECRET_KEY': 'secret_key',
         'CSRF_ENABLED': False,
         'WTF_CSRF_ENABLED': False,
     })
     self.db = SQLA(self.app)
     builder = AppBuilder(self.app, self.db.session,
                          security_manager_class=AirflowSecurityManager)
     self.appbuilder = builder
     self.security_manager = builder.sm

     # One plain view and one model view, each under its own menu category.
     for view_class, view_name, menu_category in (
             (SomeBaseView, "SomeBaseView", "BaseViews"),
             (SomeModelView, "SomeModelView", "ModelViews")):
         builder.add_view(view_class, view_name, category=menu_category)

     admin_role = self.security_manager.find_role('Admin')
     self.user = builder.sm.add_user('admin', 'admin', 'user', '*****@*****.**',
                                     admin_role, 'general')
     log.debug("Complete setup!")
Beispiel #14
0
    def setUp(self):
        """Start the mocked LDAP server and build an LDAP-configured app with a database.

        All configuration is assigned *before* ``SQLA`` is bound to the app.
        Flask-SQLAlchemy reads its settings (e.g.
        ``SQLALCHEMY_TRACK_MODIFICATIONS``) during initialisation, so values
        assigned afterwards can be missed.
        """
        # Mocked directory served at ldap://localhost/.
        self.mockldap.start()
        self.ldapobj = self.mockldap['ldap://localhost/']

        self.app = Flask(__name__)

        self.app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
        self.app.config['AUTH_LDAP_UID_FIELD'] = 'cn'
        self.app.config['AUTH_LDAP_ALLOW_SELF_SIGNED'] = False
        self.app.config['AUTH_LDAP_USE_TLS'] = False
        self.app.config['AUTH_LDAP_SERVER'] = 'ldap://localhost/'
        self.app.config['AUTH_LDAP_SEARCH'] = 'ou=example,o=test'
        self.app.config['AUTH_LDAP_APPEND_DOMAIN'] = False
        self.app.config['AUTH_LDAP_FIRSTNAME_FIELD'] = None
        self.app.config['AUTH_LDAP_LASTNAME_FIELD'] = None
        self.app.config['AUTH_LDAP_EMAIL_FIELD'] = None

        # Bind the database only once the configuration is complete.
        self.db = SQLA(self.app)
Beispiel #15
0
    def test_export_roles_filename(self):
        """``export_roles`` without ``--path`` should create a ``roles_export_*`` file in the CWD."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            app = Flask("src_app")
            app.config.from_object("flask_appbuilder.tests.config_security")

            app.config[
                "SQLALCHEMY_DATABASE_URI"] = f"sqlite:///{os.path.join(tmp_dir, 'src.db')}"
            db = SQLA(app)
            app_builder = AppBuilder(app, db.session)  # noqa: F841

            # Run the command from inside tmp_dir so its default output lands
            # there. Always restore the original CWD, even if the invocation
            # raises; otherwise the process would be left inside a directory
            # that TemporaryDirectory is about to delete, breaking later tests.
            original_cwd = os.getcwd()
            os.chdir(tmp_dir)
            try:
                cli_runner = app.test_cli_runner()
                export_result = cli_runner.invoke(export_roles)
            finally:
                os.chdir(original_cwd)

            self.assertEqual(export_result.exit_code, 0)
            self.assertGreater(
                len(glob.glob(os.path.join(tmp_dir, "roles_export_*"))), 0)
Beispiel #16
0
    def setUp(self):
        """Start the mocked LDAP server and build an LDAP-configured app with a database.

        All configuration is assigned *before* ``SQLA`` is bound to the app.
        Flask-SQLAlchemy reads its settings (e.g.
        ``SQLALCHEMY_TRACK_MODIFICATIONS``) during initialisation, so values
        assigned afterwards can be missed.
        """
        # Mocked directory served at ldap://localhost/.
        self.mockldap.start()
        self.ldapobj = self.mockldap["ldap://localhost/"]

        self.app = Flask(__name__)
        # Undefined template variables raise instead of rendering empty.
        self.app.jinja_env.undefined = jinja2.StrictUndefined

        self.app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
        self.app.config["AUTH_LDAP_UID_FIELD"] = "cn"
        self.app.config["AUTH_LDAP_ALLOW_SELF_SIGNED"] = False
        self.app.config["AUTH_LDAP_USE_TLS"] = False
        self.app.config["AUTH_LDAP_SERVER"] = "ldap://localhost/"
        self.app.config["AUTH_LDAP_SEARCH"] = "ou=example,o=test"
        self.app.config["AUTH_LDAP_APPEND_DOMAIN"] = False
        self.app.config["AUTH_LDAP_FIRSTNAME_FIELD"] = None
        self.app.config["AUTH_LDAP_LASTNAME_FIELD"] = None
        self.app.config["AUTH_LDAP_EMAIL_FIELD"] = None

        # Bind the database only once the configuration is complete.
        self.db = SQLA(self.app)
Beispiel #17
0
    def test_import_roles(self):
        """``import_roles --path`` should create every expected role with its permissions."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            db_file = os.path.join(tmp_dir, 'dst.db')
            app = Flask("dst_app")
            app.config["SQLALCHEMY_DATABASE_URI"] = f"sqlite:///{db_file}"
            db = SQLA(app)
            app_builder = AppBuilder(app, db.session)
            cli_runner = app.test_cli_runner()

            roles_file = os.path.join(tmp_dir, "roles.json")
            with open(roles_file, "w") as fd:
                json.dump(self.expected_roles, fd)

            # Before the import, the destination app has only Admin and Public.
            self.assertEqual(len(app_builder.sm.get_all_roles()), 2)

            import_result = cli_runner.invoke(import_roles, [f"--path={roles_file}"])
            self.assertEqual(import_result.exit_code, 0)

            roles_after_import = app_builder.sm.get_all_roles()

            for expected in self.expected_roles:
                candidates = [
                    role for role in roles_after_import
                    if role.name == expected["name"]
                ]
                self.assertTrue(candidates)
                imported_role = candidates[0]

                # Compare roles as sets of (permission name, view menu name).
                expected_pairs = {
                    (pvm["permission"]["name"], pvm["view_menu"]["name"])
                    for pvm in expected["permissions"]
                }
                actual_pairs = {
                    (pvm.permission.name, pvm.view_menu.name)
                    for pvm in imported_role.permissions
                }
                self.assertEqual(actual_pairs, expected_pairs)
Beispiel #18
0
    def test_export_roles_indent(self, mock_json_dumps):
        """Test that json.dumps is called with the correct argument passed from CLI."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            db_file = os.path.join(tmp_dir, 'src.db')
            app = Flask("src_app")
            app.config.from_object("flask_appbuilder.tests.config_security")
            app.config["SQLALCHEMY_DATABASE_URI"] = f"sqlite:///{db_file}"
            db = SQLA(app)
            app_builder = AppBuilder(app, db.session)  # noqa: F841
            cli_runner = app.test_cli_runner()

            # With no --indent option, json.dumps is expected to get indent=None.
            cli_runner.invoke(export_roles)
            mock_json_dumps.assert_called_with(ANY, indent=None)
            mock_json_dumps.reset_mock()

            # Each --indent value is expected to reach json.dumps as that value.
            for indent_value in ("", "foo", -1, 0, 1):
                cli_runner.invoke(export_roles, [f"--indent={indent_value}"])
                mock_json_dumps.assert_called_with(ANY, indent=indent_value)
                mock_json_dumps.reset_mock()
Beispiel #19
0
    def test_export_roles(self):
        """``export_roles --path`` should write every expected role and its permissions as JSON."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            db_file = os.path.join(tmp_dir, 'src.db')
            app = Flask("src_app")
            app.config.from_object("flask_appbuilder.tests.config_security")
            app.config["SQLALCHEMY_DATABASE_URI"] = f"sqlite:///{db_file}"
            db = SQLA(app)
            app_builder = AppBuilder(app, db.session)  # noqa: F841
            cli_runner = app.test_cli_runner()

            export_path = os.path.join(tmp_dir, "roles.json")
            export_result = cli_runner.invoke(export_roles, [f"--path={export_path}"])

            self.assertEqual(export_result.exit_code, 0)
            self.assertTrue(os.path.exists(export_path))

            with open(export_path, "r") as fd:
                exported_roles = json.load(fd)

            def permission_pairs(role):
                # (permission name, view menu name) for every entry of a role dict.
                return {
                    (pvm["permission"]["name"], pvm["view_menu"]["name"])
                    for pvm in role["permissions"]
                }

            for expected in self.expected_roles:
                candidates = [
                    role for role in exported_roles
                    if role["name"] == expected["name"]
                ]
                self.assertTrue(candidates)
                self.assertEqual(
                    permission_pairs(candidates[0]),
                    permission_pairs(expected),
                )
    def setUp(self):
        """Create an app with a Model1 view, an admin user and a limited user.

        The limited user's role is granted ``menu_access`` on ``Model1`` and
        ``can_get`` on ``MenuApi``.
        """
        from flask import Flask
        from flask_appbuilder import AppBuilder
        from flask_appbuilder.views import ModelView

        self.app = Flask(__name__)
        self.basedir = os.path.abspath(os.path.dirname(__file__))
        self.app.config.update(
            SQLALCHEMY_DATABASE_URI="sqlite:///",
            CSRF_ENABLED=False,
            SECRET_KEY="thisismyscretkey",
            WTF_CSRF_ENABLED=False,
            SQLALCHEMY_TRACK_MODIFICATIONS=False,
        )

        logging.basicConfig(level=logging.ERROR)

        self.db = SQLA(self.app)
        self.appbuilder = AppBuilder(self.app, self.db.session)
        security = self.appbuilder.sm

        class Model1View(ModelView):
            datamodel = SQLAInterface(Model1)

        self.appbuilder.add_view(Model1View, "Model1")

        admin_role = security.find_role("Admin")
        security.add_user(DEFAULT_ADMIN_USER, "admin", "user",
                          "*****@*****.**", admin_role,
                          DEFAULT_ADMIN_PASSWORD)

        limited_role = security.add_role("LimitedUser")
        for permission_name, view_menu_name in (
            ("menu_access", "Model1"),
            ("can_get", "MenuApi"),
        ):
            pvm = security.find_permission_view_menu(permission_name, view_menu_name)
            security.add_permission_role(limited_role, pvm)
        security.add_user(LIMITED_USER, "user1", "user1",
                          "*****@*****.**", limited_role,
                          LIMITED_USER_PASSWORD)
Beispiel #21
0
def create_app(config=None, testing=False, app_name="Airflow"):
    """Create a new instance of Airflow WWW app.

    :param config: optional mapping applied with ``config.from_mapping`` after
        the webserver config file and the defaults assigned here.
    :param testing: stored as the Flask ``TESTING`` flag.
    :param app_name: stored as ``APP_NAME`` in the Flask config.
    :return: the configured Flask application.
    """
    from tempfile import gettempdir

    flask_app = Flask(__name__)
    flask_app.secret_key = conf.get('webserver', 'SECRET_KEY')

    session_lifetime_days = conf.getint('webserver',
                                        'SESSION_LIFETIME_DAYS',
                                        fallback=30)
    flask_app.config['PERMANENT_SESSION_LIFETIME'] = timedelta(
        days=session_lifetime_days)

    flask_app.config.from_pyfile(settings.WEBSERVER_CONFIG, silent=True)
    flask_app.config['APP_NAME'] = app_name
    flask_app.config['TESTING'] = testing
    flask_app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False

    flask_app.config['SESSION_COOKIE_HTTPONLY'] = True
    flask_app.config['SESSION_COOKIE_SECURE'] = conf.getboolean(
        'webserver', 'COOKIE_SECURE')
    flask_app.config['SESSION_COOKIE_SAMESITE'] = conf.get(
        'webserver', 'COOKIE_SAMESITE')

    if config:
        flask_app.config.from_mapping(config)

    # Configure the JSON encoder used by `|tojson` filter from Flask
    flask_app.json_encoder = AirflowJsonEncoder

    CSRFProtect(flask_app)

    init_wsg_middleware(flask_app)

    # Reuse Airflow's own scoped session instead of one created by SQLA.
    db = SQLA()
    db.session = settings.Session
    db.init_app(flask_app)

    init_dagbag(flask_app)

    init_api_experimental_auth(flask_app)

    # Filesystem cache in the platform temp directory (honours TMPDIR/TEMP/TMP)
    # rather than a hard-coded, POSIX-only '/tmp'.
    Cache(app=flask_app,
          config={
              'CACHE_TYPE': 'filesystem',
              'CACHE_DIR': gettempdir()
          })

    init_flash_views(flask_app)

    configure_logging()
    configure_manifest_files(flask_app)

    # AppBuilder setup and view/API registration need an application context.
    with flask_app.app_context():
        init_appbuilder(flask_app)

        init_appbuilder_views(flask_app)
        init_appbuilder_links(flask_app)
        init_plugins(flask_app)
        init_error_handlers(flask_app)
        init_api_connexion(flask_app)
        init_api_experimental(flask_app)

        sync_appbuilder_roles(flask_app)

        init_jinja_globals(flask_app)
        init_logout_timeout(flask_app)
        init_xframe_protection(flask_app)
        init_permanent_session(flask_app)

    return flask_app
Beispiel #22
0
def downgrade():
    """Revert the DAG-permission migration using Airflow's ORM session.

    Like ``upgrade()``, this commits the session and then closes it. The close
    happens in a ``finally`` block so the connection is released even if the
    revert raises. Committing here persists the revert whether or not the
    helper commits on its own; an extra commit is harmless.
    """
    db = SQLA()
    db.session = settings.Session
    try:
        undo_migrate_to_new_dag_permissions(db.session)
        db.session.commit()
    finally:
        db.session.close()
Beispiel #23
0
def upgrade():
    """Apply the DAG-permission migration with Airflow's ORM session and commit it.

    The session is closed in a ``finally`` block so its connection is released
    even when the migration raises (an uncommitted session is rolled back on
    close).
    """
    db = SQLA()
    db.session = settings.Session
    try:
        migrate_to_new_dag_permissions(db)
        db.session.commit()
    finally:
        db.session.close()
Beispiel #24
0
import logging

from flask import Flask
from flask_appbuilder import AppBuilder, SQLA
from flask_cors import CORS

# Logging configuration: timestamped records, root logger at DEBUG.
logging.basicConfig(format="%(asctime)s:%(levelname)s:%(name)s:%(message)s")
logging.getLogger().setLevel(logging.DEBUG)

app = Flask(__name__)
# Enable CORS for the app (flask-cors defaults).
CORS(app)

app.config.from_object("config")
db = SQLA(app)
appbuilder = AppBuilder(app, db.session)

# Optional, SQLite only: enforce foreign-key constraints on every connection.
# Uncomment the block below (including its two imports) to enable it.
#
# from sqlalchemy.engine import Engine
# from sqlalchemy import event
#
# @event.listens_for(Engine, "connect")
# def set_sqlite_pragma(dbapi_connection, connection_record):
#     # Force SQLite to enforce foreign-key constraints.
#     cursor = dbapi_connection.cursor()
#     cursor.execute("PRAGMA foreign_keys=ON")
#     cursor.close()

# Imported last on purpose: the views module presumably imports ``appbuilder``
# / ``db`` from this package, so importing it earlier would be circular.
from . import views  # noqa: E402,F401
Beispiel #25
0
def create_app(config=None, session=None, testing=False, app_name="Airflow"):
    """Create the Airflow web UI Flask app and its Flask-AppBuilder instance.

    The results are also published through the module-level ``app`` and
    ``appbuilder`` globals.

    :param config: optional mapping applied with ``config.from_mapping`` after
        the webserver config file and the defaults assigned here.
    :param session: optional SQLAlchemy session given to ``AppBuilder``
        instead of ``db.session``.
    :param testing: stored as the Flask ``TESTING`` flag; when true, the
        experimental API endpoints module is reloaded before registration.
    :param app_name: stored as ``APP_NAME`` in the Flask config.
    :return: the tuple ``(app, appbuilder)``.
    :raises Exception: if the configured ``SECURITY_MANAGER_CLASS`` does not
        subclass ``AirflowSecurityManager``.
    """
    global app, appbuilder
    from tempfile import gettempdir

    app = Flask(__name__)
    if conf.getboolean('webserver', 'ENABLE_PROXY_FIX'):
        # Trust X-Forwarded-* headers from the configured number of proxies.
        app.wsgi_app = ProxyFix(app.wsgi_app,
                                x_for=conf.getint("webserver",
                                                  "PROXY_FIX_X_FOR",
                                                  fallback=1),
                                x_proto=conf.getint("webserver",
                                                    "PROXY_FIX_X_PROTO",
                                                    fallback=1),
                                x_host=conf.getint("webserver",
                                                   "PROXY_FIX_X_HOST",
                                                   fallback=1),
                                x_port=conf.getint("webserver",
                                                   "PROXY_FIX_X_PORT",
                                                   fallback=1),
                                x_prefix=conf.getint("webserver",
                                                     "PROXY_FIX_X_PREFIX",
                                                     fallback=1))
    app.secret_key = conf.get('webserver', 'SECRET_KEY')

    session_lifetime_days = conf.getint('webserver',
                                        'SESSION_LIFETIME_DAYS',
                                        fallback=30)
    app.config['PERMANENT_SESSION_LIFETIME'] = timedelta(
        days=session_lifetime_days)

    app.config.from_pyfile(settings.WEBSERVER_CONFIG, silent=True)
    app.config['APP_NAME'] = app_name
    app.config['TESTING'] = testing
    app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False

    app.config['SESSION_COOKIE_HTTPONLY'] = True
    app.config['SESSION_COOKIE_SECURE'] = conf.getboolean(
        'webserver', 'COOKIE_SECURE')
    app.config['SESSION_COOKIE_SAMESITE'] = conf.get('webserver',
                                                     'COOKIE_SAMESITE')

    if config:
        app.config.from_mapping(config)

    # Configure the JSON encoder used by `|tojson` filter from Flask
    app.json_encoder = AirflowJsonEncoder

    csrf.init_app(app)

    db = SQLA(app)

    from airflow import api
    api.load_auth()
    api.API_AUTH.api_auth.init_app(app)

    # Filesystem cache in the platform temp directory (honours TMPDIR/TEMP/TMP)
    # rather than a hard-coded, POSIX-only '/tmp'.
    Cache(app=app, config={'CACHE_TYPE': 'filesystem', 'CACHE_DIR': gettempdir()})

    from airflow.www.blueprints import routes
    app.register_blueprint(routes)

    configure_logging()
    configure_manifest_files(app)

    with app.app_context():
        from airflow.www.security import AirflowSecurityManager
        security_manager_class = app.config.get('SECURITY_MANAGER_CLASS') or \
            AirflowSecurityManager

        if not issubclass(security_manager_class, AirflowSecurityManager):
            raise Exception(
                """Your CUSTOM_SECURITY_MANAGER must now extend AirflowSecurityManager,
                 not FAB's security manager.""")

        appbuilder = AppBuilder(app,
                                db.session if not session else session,
                                security_manager_class=security_manager_class,
                                base_template='airflow/master.html',
                                update_perms=conf.getboolean(
                                    'webserver', 'UPDATE_FAB_PERMS'))

        def init_views(appbuilder):
            """Register Airflow's views, doc links and plugin views on ``appbuilder``."""
            from airflow.www import views
            # Remove the session from scoped_session registry to avoid
            # reusing a session with a disconnected connection
            appbuilder.session.remove()
            appbuilder.add_view_no_menu(views.Airflow())
            appbuilder.add_view_no_menu(views.DagModelView())
            appbuilder.add_view(views.DagRunModelView,
                                "DAG Runs",
                                category="Browse",
                                category_icon="fa-globe")
            appbuilder.add_view(views.JobModelView, "Jobs", category="Browse")
            appbuilder.add_view(views.LogModelView, "Logs", category="Browse")
            appbuilder.add_view(views.SlaMissModelView,
                                "SLA Misses",
                                category="Browse")
            appbuilder.add_view(views.TaskInstanceModelView,
                                "Task Instances",
                                category="Browse")
            appbuilder.add_view(views.ConfigurationView,
                                "Configurations",
                                category="Admin",
                                category_icon="fa-user")
            appbuilder.add_view(views.ConnectionModelView,
                                "Connections",
                                category="Admin")
            appbuilder.add_view(views.PoolModelView, "Pools", category="Admin")
            appbuilder.add_view(views.VariableModelView,
                                "Variables",
                                category="Admin")
            appbuilder.add_view(views.XComModelView, "XComs", category="Admin")

            # Development builds link to the "latest" docs; releases link to
            # the docs for their exact version.
            if "dev" in version.version:
                airflow_doc_site = "https://airflow.readthedocs.io/en/latest"
            else:
                airflow_doc_site = 'https://airflow.apache.org/docs/{}'.format(
                    version.version)

            appbuilder.add_link("Website",
                                href='https://airflow.apache.org',
                                category="Docs",
                                category_icon="fa-globe")
            appbuilder.add_link("Documentation",
                                href=airflow_doc_site,
                                category="Docs",
                                category_icon="fa-cube")
            appbuilder.add_link("GitHub",
                                href='https://github.com/apache/airflow',
                                category="Docs")
            appbuilder.add_view(views.VersionView,
                                'Version',
                                category='About',
                                category_icon='fa-th')

            def integrate_plugins():
                """Integrate plugins to the context"""
                from airflow import plugins_manager

                plugins_manager.initialize_web_ui_plugins()

                for v in plugins_manager.flask_appbuilder_views:
                    log.debug("Adding view %s", v["name"])
                    appbuilder.add_view(v["view"],
                                        v["name"],
                                        category=v["category"])
                # Menu links are added in name order.
                for ml in sorted(plugins_manager.flask_appbuilder_menu_links,
                                 key=lambda x: x["name"]):
                    log.debug("Adding menu link %s", ml["name"])
                    appbuilder.add_link(ml["name"],
                                        href=ml["href"],
                                        category=ml["category"],
                                        category_icon=ml["category_icon"])

            integrate_plugins()
            # NOTE(review): stale permissions/views are not garbage-collected
            # here, so when a view or menu is renamed the old names may remain
            # in the backend alongside the new ones. Confirm whether a cleanup
            # step is intended at this point.

        def init_plugin_blueprints(app):
            """Register every blueprint contributed by plugins."""
            from airflow.plugins_manager import flask_blueprints

            for bp in flask_blueprints:
                log.debug("Adding blueprint %s:%s", bp["name"],
                          bp["blueprint"].import_name)
                app.register_blueprint(bp["blueprint"])

        def init_error_handlers(app: Flask):
            """Route 500 errors to the traceback page and 404 errors to ``views.circles``."""
            from airflow.www import views
            app.register_error_handler(500, views.show_traceback)
            app.register_error_handler(404, views.circles)

        init_views(appbuilder)
        init_plugin_blueprints(app)
        init_error_handlers(app)

        if conf.getboolean('webserver', 'UPDATE_FAB_PERMS'):
            security_manager = appbuilder.sm
            security_manager.sync_roles()

        from airflow.www.api.experimental import endpoints as e
        # required for testing purposes otherwise the module retains
        # a link to the default_auth
        if app.config['TESTING']:
            import importlib
            importlib.reload(e)

        app.register_blueprint(e.api_experimental,
                               url_prefix='/api/experimental')

        # Resolve the configured "system"/"utc" aliases to concrete zone names.
        server_timezone = conf.get('core', 'default_timezone')
        if server_timezone == "system":
            server_timezone = pendulum.local_timezone().name
        elif server_timezone == "utc":
            server_timezone = "UTC"

        default_ui_timezone = conf.get('webserver', 'default_ui_timezone')
        if default_ui_timezone == "system":
            default_ui_timezone = pendulum.local_timezone().name
        elif default_ui_timezone == "utc":
            default_ui_timezone = "UTC"
        if not default_ui_timezone:
            default_ui_timezone = server_timezone

        @app.context_processor
        def jinja_globals():  # pylint: disable=unused-variable
            """Return the variables injected into every template context."""

            # Named to avoid shadowing the ``globals`` builtin.
            template_globals = {
                'server_timezone':
                server_timezone,
                'default_ui_timezone':
                default_ui_timezone,
                'hostname':
                socket.getfqdn() if conf.getboolean(
                    'webserver', 'EXPOSE_HOSTNAME', fallback=True) else
                'redact',
                'navbar_color':
                conf.get('webserver', 'NAVBAR_COLOR'),
                'log_fetch_delay_sec':
                conf.getint('webserver', 'log_fetch_delay_sec', fallback=2),
                'log_auto_tailing_offset':
                conf.getint('webserver',
                            'log_auto_tailing_offset',
                            fallback=30),
                'log_animation_speed':
                conf.getint('webserver', 'log_animation_speed', fallback=1000)
            }

            if 'analytics_tool' in conf.getsection('webserver'):
                template_globals.update({
                    'analytics_tool':
                    conf.get('webserver', 'ANALYTICS_TOOL'),
                    'analytics_id':
                    conf.get('webserver', 'ANALYTICS_ID')
                })

            return template_globals

        @app.before_request
        def before_request():
            """When FORCE_LOG_OUT_AFTER > 0, make the session permanent with that lifetime (minutes)."""
            _force_log_out_after = conf.getint('webserver',
                                               'FORCE_LOG_OUT_AFTER',
                                               fallback=0)
            if _force_log_out_after > 0:
                flask.session.permanent = True
                app.permanent_session_lifetime = datetime.timedelta(
                    minutes=_force_log_out_after)
                flask.session.modified = True
                flask.g.user = flask_login.current_user

        @app.after_request
        def apply_caching(response):
            """Send ``X-Frame-Options: DENY`` unless X_FRAME_ENABLED is true."""
            _x_frame_enabled = conf.getboolean('webserver',
                                               'X_FRAME_ENABLED',
                                               fallback=True)
            if not _x_frame_enabled:
                response.headers["X-Frame-Options"] = "DENY"
            return response

        @app.teardown_appcontext
        def shutdown_session(exception=None):  # pylint: disable=unused-variable
            settings.Session.remove()

        @app.before_request
        def make_session_permanent():
            flask_session.permanent = True

    return app, appbuilder
logging.basicConfig(format='%(levelname)s:%(name)s:%(message)s')
logging.getLogger().setLevel(logging.DEBUG)
log = logging.getLogger('Database Migration to 1.3')

# The database connection string is the one required command-line argument.
if len(sys.argv) < 2:
    log.info("Without typical app structure use parameter to config")
    log.info(
        "Use example for sqlite: python migrate_db_1.3.py sqlite:////home/user/application/app.db"
    )
    # ``exit()`` is an interactive helper added by the ``site`` module (absent
    # under ``python -S``); use ``sys.exit`` and report the usage error with a
    # non-zero status so calling scripts can detect it.
    sys.exit(1)

con_str = sys.argv[1]
app = Flask(__name__)
app.config['SQLALCHEMY_DATABASE_URI'] = con_str
db = SQLA(app)

# Per-dialect templates for adding a column: (table, column, column type).
add_column_stmt = {
    'mysql': 'ALTER TABLE %s ADD COLUMN %s %s',
    'sqlite': 'ALTER TABLE %s ADD COLUMN %s %s',
    'postgresql': 'ALTER TABLE %s ADD COLUMN %s %s'
}

# Per-dialect templates for changing a column's type: (table, column, type).
# SQLite has no ALTER COLUMN support, hence the empty template.
mod_column_stmt = {
    'mysql': 'ALTER TABLE %s MODIFY COLUMN %s %s',
    'sqlite': '',
    'postgresql': 'ALTER TABLE %s ALTER COLUMN %s TYPE %s'
}

def check_engine_support(conn):
Beispiel #27
0
def create_app(config=None, testing=False, app_name="Airflow"):
    """Create a new instance of Airflow WWW app.

    :param config: optional mapping applied after the webserver config file
        and the built-in settings below, so its keys take precedence.
    :param testing: stored as ``flask_app.config['TESTING']``.
    :param app_name: stored as ``flask_app.config['APP_NAME']``.
    :return: the configured Flask application.

    Exits the process with status 4 when the removed ``SESSION_LIFETIME_DAYS``
    or ``FORCE_LOG_OUT_AFTER`` webserver options are still configured.
    """
    from tempfile import gettempdir

    flask_app = Flask(__name__)
    flask_app.secret_key = conf.get('webserver', 'SECRET_KEY')

    if conf.has_option('webserver',
                       'SESSION_LIFETIME_DAYS') or conf.has_option(
                           'webserver', 'FORCE_LOG_OUT_AFTER'):
        logging.error(
            '`SESSION_LIFETIME_DAYS` option from `webserver` section has been '
            'renamed to `SESSION_LIFETIME_MINUTES`. New option allows to configure '
            'session lifetime in minutes. FORCE_LOG_OUT_AFTER option has been removed '
            'from `webserver` section. Please update your configuration.')
        # Stop gunicorn server https://github.com/benoitc/gunicorn/blob/20.0.4/gunicorn/arbiter.py#L526
        sys.exit(4)

    session_lifetime_minutes = conf.getint('webserver',
                                           'SESSION_LIFETIME_MINUTES',
                                           fallback=43200)
    logging.info('User session lifetime is set to %s minutes.',
                 session_lifetime_minutes)

    flask_app.config['PERMANENT_SESSION_LIFETIME'] = timedelta(
        minutes=session_lifetime_minutes)

    flask_app.config.from_pyfile(settings.WEBSERVER_CONFIG, silent=True)
    flask_app.config['APP_NAME'] = app_name
    flask_app.config['TESTING'] = testing
    flask_app.config['SQLALCHEMY_DATABASE_URI'] = conf.get(
        'core', 'SQL_ALCHEMY_CONN')
    flask_app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False

    flask_app.config['SESSION_COOKIE_HTTPONLY'] = True
    flask_app.config['SESSION_COOKIE_SECURE'] = conf.getboolean(
        'webserver', 'COOKIE_SECURE')
    flask_app.config['SESSION_COOKIE_SAMESITE'] = conf.get(
        'webserver', 'COOKIE_SAMESITE')

    if config:
        flask_app.config.from_mapping(config)

    if 'SQLALCHEMY_ENGINE_OPTIONS' not in flask_app.config:
        flask_app.config[
            'SQLALCHEMY_ENGINE_OPTIONS'] = settings.prepare_engine_args()

    # Configure the JSON encoder used by `|tojson` filter from Flask
    flask_app.json_encoder = AirflowJsonEncoder

    csrf.init_app(flask_app)

    init_wsgi_middleware(flask_app)

    # Reuse Airflow's scoped session instead of letting Flask-SQLAlchemy
    # create its own.
    db = SQLA()
    db.session = settings.Session
    db.init_app(flask_app)

    init_dagbag(flask_app)

    init_api_experimental_auth(flask_app)

    # Filesystem cache in the platform temp directory (honours TMPDIR etc.)
    # rather than a hard-coded, POSIX-only '/tmp'.
    Cache(app=flask_app,
          config={
              'CACHE_TYPE': 'filesystem',
              'CACHE_DIR': gettempdir()
          })

    init_flash_views(flask_app)

    configure_logging()
    configure_manifest_files(flask_app)

    with flask_app.app_context():
        init_appbuilder(flask_app)

        init_appbuilder_views(flask_app)
        init_appbuilder_links(flask_app)
        init_plugins(flask_app)
        init_error_handlers(flask_app)
        init_api_connexion(flask_app)
        init_api_experimental(flask_app)

        sync_appbuilder_roles(flask_app)

        init_jinja_globals(flask_app)
        init_xframe_protection(flask_app)
        init_permanent_session(flask_app)

    return flask_app
Beispiel #28
0
def create_app(config=None, session=None, testing=False, app_name="Airflow"):
    """Create the Airflow webserver Flask app and its Flask-AppBuilder instance.

    Also rebinds the module-level ``app`` and ``appbuilder`` globals.

    :param config: optional mapping applied after the webserver config file,
        so its keys take precedence.
    :param session: SQLAlchemy session handed to AppBuilder; when falsy the
        session of the ``SQLA`` extension created here is used.
    :param testing: stored as ``app.config['TESTING']``; when true the
        experimental API endpoints module is reloaded.
    :param app_name: stored as ``app.config['APP_NAME']``.
    :return: ``(app, appbuilder)``.
    :raises Exception: if ``SECURITY_MANAGER_CLASS`` is not a subclass of
        ``AirflowSecurityManager``.
    """
    global app, appbuilder
    from tempfile import gettempdir

    app = Flask(__name__)
    if conf.getboolean('webserver', 'ENABLE_PROXY_FIX'):
        # ``num_proxies`` is intentionally not passed: it was deprecated in
        # Werkzeug 0.15 (where ``None`` was already its default) and removed
        # in Werkzeug 1.0, where passing it raises TypeError.
        app.wsgi_app = ProxyFix(app.wsgi_app,
                                x_for=1,
                                x_proto=1,
                                x_host=1,
                                x_port=1,
                                x_prefix=1)
    app.secret_key = conf.get('webserver', 'SECRET_KEY')

    app.config.from_pyfile(settings.WEBSERVER_CONFIG, silent=True)
    app.config['APP_NAME'] = app_name
    app.config['TESTING'] = testing
    app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False

    app.config['SESSION_COOKIE_HTTPONLY'] = True
    app.config['SESSION_COOKIE_SECURE'] = conf.getboolean(
        'webserver', 'COOKIE_SECURE')
    app.config['SESSION_COOKIE_SAMESITE'] = conf.get('webserver',
                                                     'COOKIE_SAMESITE')

    if config:
        app.config.from_mapping(config)

    # Configure the JSON encoder used by `|tojson` filter from Flask
    app.json_encoder = AirflowJsonEncoder

    csrf.init_app(app)

    db = SQLA(app)

    from airflow import api
    api.load_auth()
    api.API_AUTH.api_auth.init_app(app)

    # The Cache object registers itself on the app; no local reference is
    # needed. Files go to the platform temp dir (honours TMPDIR) instead of a
    # hard-coded, POSIX-only '/tmp'.
    Cache(app=app,
          config={
              'CACHE_TYPE': 'filesystem',
              'CACHE_DIR': gettempdir()
          })

    from airflow.www.blueprints import routes
    app.register_blueprint(routes)

    configure_logging()
    configure_manifest_files(app)

    with app.app_context():

        from airflow.www.security import AirflowSecurityManager
        security_manager_class = app.config.get('SECURITY_MANAGER_CLASS') or \
            AirflowSecurityManager

        if not issubclass(security_manager_class, AirflowSecurityManager):
            raise Exception(
                """Your CUSTOM_SECURITY_MANAGER must now extend AirflowSecurityManager,
                 not FAB's security manager.""")

        appbuilder = AppBuilder(app,
                                db.session if not session else session,
                                security_manager_class=security_manager_class,
                                base_template='appbuilder/baselayout.html')

        def init_views(appbuilder):
            """Register the built-in Airflow views, menu links and plugin views."""
            from airflow.www import views
            appbuilder.add_view_no_menu(views.Airflow())
            appbuilder.add_view_no_menu(views.DagModelView())
            appbuilder.add_view_no_menu(views.ConfigurationView())
            appbuilder.add_view_no_menu(views.VersionView())
            appbuilder.add_view(views.DagRunModelView,
                                "DAG Runs",
                                category="Browse",
                                category_icon="fa-globe")
            appbuilder.add_view(views.JobModelView, "Jobs", category="Browse")
            appbuilder.add_view(views.LogModelView, "Logs", category="Browse")
            appbuilder.add_view(views.SlaMissModelView,
                                "SLA Misses",
                                category="Browse")
            appbuilder.add_view(views.TaskInstanceModelView,
                                "Task Instances",
                                category="Browse")
            appbuilder.add_link("Configurations",
                                href='/configuration',
                                category="Admin",
                                category_icon="fa-user")
            appbuilder.add_view(views.ConnectionModelView,
                                "Connections",
                                category="Admin")
            appbuilder.add_view(views.PoolModelView, "Pools", category="Admin")
            appbuilder.add_view(views.VariableModelView,
                                "Variables",
                                category="Admin")
            appbuilder.add_view(views.XComModelView, "XComs", category="Admin")
            appbuilder.add_link("Documentation",
                                href='https://airflow.apache.org/',
                                category="Docs",
                                category_icon="fa-cube")
            appbuilder.add_link("GitHub",
                                href='https://github.com/apache/airflow',
                                category="Docs")
            appbuilder.add_link('Version',
                                href='/version',
                                category='About',
                                category_icon='fa-th')

            def integrate_plugins():
                """Integrate plugins to the context"""
                from airflow.plugins_manager import (
                    flask_appbuilder_views, flask_appbuilder_menu_links)

                for v in flask_appbuilder_views:
                    log.debug("Adding view %s", v["name"])
                    appbuilder.add_view(v["view"],
                                        v["name"],
                                        category=v["category"])
                for ml in sorted(flask_appbuilder_menu_links,
                                 key=lambda x: x["name"]):
                    log.debug("Adding menu link %s", ml["name"])
                    appbuilder.add_link(ml["name"],
                                        href=ml["href"],
                                        category=ml["category"],
                                        category_icon=ml["category_icon"])

            integrate_plugins()
            # NOTE(review): no cleanup of stale permissions/views happens in
            # this function; the only permission sync visible here is the
            # ``security_manager.sync_roles()`` call below.

        def init_plugin_blueprints(app):
            """Register every Flask blueprint contributed by plugins."""
            from airflow.plugins_manager import flask_blueprints

            for bp in flask_blueprints:
                log.debug("Adding blueprint %s:%s", bp["name"],
                          bp["blueprint"].import_name)
                app.register_blueprint(bp["blueprint"])

        init_views(appbuilder)
        init_plugin_blueprints(app)

        security_manager = appbuilder.sm
        security_manager.sync_roles()

        from airflow.www.api.experimental import endpoints as e
        # required for testing purposes otherwise the module retains
        # a link to the default_auth
        if app.config['TESTING']:
            import importlib
            importlib.reload(e)

        app.register_blueprint(e.api_experimental,
                               url_prefix='/api/experimental')

        @app.context_processor
        def jinja_globals():
            return {
                'hostname': socket.getfqdn(),
                'navbar_color': conf.get('webserver', 'NAVBAR_COLOR'),
            }

        @app.teardown_appcontext
        def shutdown_session(exception=None):
            settings.Session.remove()

    return app, appbuilder
Beispiel #29
0
    def setUp(self):
        """Build a fresh Flask/AppBuilder app for each test.

        Configures a SQLite database URI with no file path (``sqlite:///``),
        disables CSRF, declares a "ReadOnly" role via ``FAB_ROLES`` that is
        limited to ``can_list``/``can_show`` on every view, registers the
        Model1/Model2/Model3/enum/generic-PS test views, and creates an admin
        user plus a read-only user holding the "ReadOnly" role.
        """
        from flask import Flask
        from flask_appbuilder import AppBuilder
        from flask_appbuilder.models.sqla.interface import SQLAInterface
        from flask_appbuilder.views import ModelView
        from sqlalchemy.engine import Engine
        from sqlalchemy import event

        self.app = Flask(__name__)
        # Fail loudly on undefined template variables instead of rendering "".
        self.app.jinja_env.undefined = jinja2.StrictUndefined
        self.basedir = os.path.abspath(os.path.dirname(__file__))
        self.app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///"
        self.app.config["CSRF_ENABLED"] = False
        self.app.config["SECRET_KEY"] = "thisismyscretkey"
        self.app.config["WTF_CSRF_ENABLED"] = False
        self.app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
        self.app.config["FAB_ROLES"] = {
            "ReadOnly": [
                [".*", "can_list"],
                [".*", "can_show"]
            ]
        }
        logging.basicConfig(level=logging.ERROR)

        # Listener is attached to the Engine class, i.e. to every engine's
        # "connect" event. NOTE(review): a new listener function is registered
        # on every setUp call and nothing here removes it — confirm that the
        # accumulation across tests is acceptable.
        @event.listens_for(Engine, "connect")
        def set_sqlite_pragma(dbapi_connection, connection_record):
            # Turn on SQLite foreign key constraint enforcement for this connection
            cursor = dbapi_connection.cursor()
            cursor.execute("PRAGMA foreign_keys=ON")
            cursor.close()

        self.db = SQLA(self.app)
        self.appbuilder = AppBuilder(self.app, self.db.session)

        # Backing data source for the GenericInterface-based PSView below.
        sess = PSSession()

        class PSView(ModelView):
            datamodel = GenericInterface(PSModel, sess)
            base_permissions = ["can_list", "can_show"]
            list_columns = ["UID", "C", "CMD", "TIME"]
            search_columns = ["UID", "C", "CMD"]

        # Relationship field "group" is filtered to G2 on edit and G1 on add.
        class Model2View(ModelView):
            datamodel = SQLAInterface(Model2)
            list_columns = [
                "field_integer",
                "field_float",
                "field_string",
                "field_method",
                "group.field_string",
            ]
            edit_form_query_rel_fields = {
                "group": [["field_string", FilterEqual, "G2"]]
            }
            add_form_query_rel_fields = {"group": [["field_string", FilterEqual, "G1"]]}

        class Model22View(ModelView):
            datamodel = SQLAInterface(Model2)
            list_columns = [
                "field_integer",
                "field_float",
                "field_string",
                "field_method",
                "group.field_string",
            ]
            add_exclude_columns = ["excluded_string"]
            edit_exclude_columns = ["excluded_string"]
            show_exclude_columns = ["excluded_string"]

        class Model1View(ModelView):
            datamodel = SQLAInterface(Model1)
            related_views = [Model2View]
            list_columns = ["field_string", "field_file"]

        # Model3 uses a composite primary key (pk1, pk2).
        class Model3View(ModelView):
            datamodel = SQLAInterface(Model3)
            list_columns = ["pk1", "pk2", "field_string"]
            add_columns = ["pk1", "pk2", "field_string"]
            edit_columns = ["pk1", "pk2", "field_string"]

        class Model1CompactView(CompactCRUDMixin, ModelView):
            datamodel = SQLAInterface(Model1)

        class Model3CompactView(CompactCRUDMixin, ModelView):
            datamodel = SQLAInterface(Model3)

        # After add/edit/delete, every action redirects to the show page of
        # REDIRECT_OBJ_ID. NOTE(review): ``obj_id`` is not read by the
        # redirect methods below.
        class Model1ViewWithRedirects(ModelView):
            datamodel = SQLAInterface(Model1)
            obj_id = 1

            def post_add_redirect(self):
                return redirect(
                    "/model1viewwithredirects/show/{0}".format(REDIRECT_OBJ_ID)
                )

            def post_edit_redirect(self):
                return redirect(
                    "/model1viewwithredirects/show/{0}".format(REDIRECT_OBJ_ID)
                )

            def post_delete_redirect(self):
                return redirect(
                    "/model1viewwithredirects/show/{0}".format(REDIRECT_OBJ_ID)
                )

        class Model1Filtered1View(ModelView):
            datamodel = SQLAInterface(Model1)
            base_filters = [["field_string", FilterStartsWith, "a"]]

        class Model1MasterView(MasterDetailView):
            datamodel = SQLAInterface(Model1)
            related_views = [Model2View]

        class Model1Filtered2View(ModelView):
            datamodel = SQLAInterface(Model1)
            base_filters = [["field_integer", FilterEqual, 0]]

        class Model2ChartView(ChartView):
            datamodel = SQLAInterface(Model2)
            chart_title = "Test Model1 Chart"
            group_by_columns = ["field_string"]

        class Model2GroupByChartView(GroupByChartView):
            datamodel = SQLAInterface(Model2)
            chart_title = "Test Model1 Chart"

            definitions = [
                {
                    "group": "field_string",
                    "series": [
                        (
                            aggregate_sum,
                            "field_integer",
                            aggregate_avg,
                            "field_integer",
                            aggregate_count,
                            "field_integer",
                        )
                    ],
                }
            ]

        class Model2DirectByChartView(DirectByChartView):
            datamodel = SQLAInterface(Model2)
            chart_title = "Test Model1 Chart"
            list_title = ""

            definitions = [
                {"group": "field_string", "series": ["field_integer", "field_float"]}
            ]

        class Model2TimeChartView(TimeChartView):
            datamodel = SQLAInterface(Model2)
            chart_title = "Test Model1 Chart"
            group_by_columns = ["field_date"]

        class Model2DirectChartView(DirectChartView):
            datamodel = SQLAInterface(Model2)
            chart_title = "Test Model1 Chart"
            direct_columns = {"stat1": ("group", "field_integer")}

        class Model1MasterChartView(MasterDetailView):
            datamodel = SQLAInterface(Model1)
            related_views = [Model2DirectByChartView]

        # Every value in field_string is rendered as the constant "FORMATTED_STRING".
        class Model1FormattedView(ModelView):
            datamodel = SQLAInterface(Model1)
            list_columns = ["field_string"]
            show_columns = ["field_string"]
            formatters_columns = {"field_string": lambda x: "FORMATTED_STRING"}

        class ModelWithEnumsView(ModelView):
            datamodel = SQLAInterface(ModelWithEnums)

        # Register Model1 views, mostly under the "Model1" menu category.
        self.appbuilder.add_view(Model1View, "Model1", category="Model1")
        self.appbuilder.add_view(
            Model1ViewWithRedirects, "Model1ViewWithRedirects", category="Model1"
        )
        self.appbuilder.add_view(Model1CompactView, "Model1Compact", category="Model1")
        self.appbuilder.add_view(Model1MasterView, "Model1Master", category="Model1")
        self.appbuilder.add_view(
            Model1MasterChartView, "Model1MasterChart", category="Model1"
        )
        self.appbuilder.add_view(
            Model1Filtered1View, "Model1Filtered1", category="Model1"
        )
        self.appbuilder.add_view(
            Model1Filtered2View, "Model1Filtered2", category="Model1"
        )
        self.appbuilder.add_view(
            Model1FormattedView, "Model1FormattedView", category="Model1FormattedView"
        )

        # Model2 views and charts; Model2View is registered a second time with
        # an href pointing straight at its add form.
        self.appbuilder.add_view(Model2View, "Model2")
        self.appbuilder.add_view(Model22View, "Model22")
        self.appbuilder.add_view(Model2View, "Model2 Add", href="/model2view/add")
        self.appbuilder.add_view(Model2ChartView, "Model2 Chart")
        self.appbuilder.add_view(Model2GroupByChartView, "Model2 Group By Chart")
        self.appbuilder.add_view(Model2DirectByChartView, "Model2 Direct By Chart")
        self.appbuilder.add_view(Model2TimeChartView, "Model2 Time Chart")
        self.appbuilder.add_view(Model2DirectChartView, "Model2 Direct Chart")

        self.appbuilder.add_view(Model3View, "Model3")
        self.appbuilder.add_view(Model3CompactView, "Model3Compact")

        self.appbuilder.add_view(ModelWithEnumsView, "ModelWithEnums")

        self.appbuilder.add_view(PSView, "Generic DS PS View", category="PSView")
        # Users: an "admin" account with the Admin role, and a read-only
        # account with the FAB_ROLES-declared "ReadOnly" role.
        role_admin = self.appbuilder.sm.find_role("Admin")
        self.appbuilder.sm.add_user(
            "admin", "admin", "user", "*****@*****.**", role_admin, "general"
        )
        role_read_only = self.appbuilder.sm.find_role("ReadOnly")
        self.appbuilder.sm.add_user(
            USERNAME_READONLY,
            "readonly",
            "readonly",
            "*****@*****.**",
            role_read_only,
            PASSWORD_READONLY
        )
Beispiel #30
0
class TestSecurity(unittest.TestCase):
    """Tests for AirflowSecurityManager roles, permissions and DAG access.

    The Airflow web app is created once per class. Each test clears DAG runs
    and DAG models and registers the sample base/model views again.
    """

    @classmethod
    def setUpClass(cls):
        settings.configure_orm()
        cls.session = settings.Session
        cls.app = application.create_app(testing=True)
        cls.appbuilder = cls.app.appbuilder  # pylint: disable=no-member
        cls.app.config['WTF_CSRF_ENABLED'] = False
        cls.security_manager = cls.appbuilder.sm
        # Remove roles the tests create, in case an earlier run left them behind.
        cls.delete_roles()

    def setUp(self):
        clear_db_runs()
        clear_db_dags()
        self.db = SQLA(self.app)
        self.appbuilder.add_view(SomeBaseView,
                                 "SomeBaseView",
                                 category="BaseViews")
        self.appbuilder.add_view(SomeModelView,
                                 "SomeModelView",
                                 category="ModelViews")

        log.debug("Complete setup!")

    @classmethod
    def delete_roles(cls):
        """Delete the fixed set of role names used by these tests."""
        for role_name in [
                'team-a', 'MyRole1', 'MyRole5', 'Test_Role', 'MyRole3',
                'MyRole2'
        ]:
            fab_utils.delete_role(cls.app, role_name)

    def expect_user_is_in_role(self, user, rolename):
        """Ensure ``rolename`` exists and make it the user's only role."""
        self.security_manager.init_role(rolename, [])
        role = self.security_manager.find_role(rolename)
        if not role:
            self.security_manager.add_role(rolename)
            role = self.security_manager.find_role(rolename)
        user.roles = [role]
        self.security_manager.update_user(user)

    def assert_user_has_dag_perms(self, perms, dag_id, user=None):
        """Fail unless ``user`` holds every permission in ``perms`` on ``dag_id``."""
        for perm in perms:
            self.assertTrue(
                self._has_dag_perm(perm, dag_id, user),
                f"User should have '{perm}' on DAG '{dag_id}'",
            )

    def assert_user_does_not_have_dag_perms(self, dag_id, perms, user=None):
        """Fail if ``user`` holds any permission in ``perms`` on ``dag_id``.

        Note the parameter order (``dag_id`` before ``perms``) differs from
        ``assert_user_has_dag_perms``; callers pass them by keyword.
        """
        for perm in perms:
            self.assertFalse(
                self._has_dag_perm(perm, dag_id, user),
                f"User should not have '{perm}' on DAG '{dag_id}'",
            )

    def _has_dag_perm(self, perm, dag_id, user):
        """Return whether ``user`` has ``perm`` on the prefixed view of ``dag_id``.

        ``user`` is passed straight through to ``has_access`` (it may be None).
        """
        return self.security_manager.has_access(
            perm, self.security_manager.prefixed_dag_id(dag_id), user)

    def tearDown(self):
        clear_db_runs()
        clear_db_dags()
        self.appbuilder = None
        self.app = None
        self.db = None
        log.debug("Complete teardown!")

    def test_init_role_baseview(self):
        role_name = 'MyRole3'
        role_perms = [('can_some_action', 'SomeBaseView')]
        self.security_manager.init_role(role_name, perms=role_perms)
        role = self.appbuilder.sm.find_role(role_name)
        self.assertIsNotNone(role)
        self.assertEqual(len(role_perms), len(role.permissions))

    def test_init_role_modelview(self):
        role_name = 'MyRole2'
        role_perms = [
            ('can_list', 'SomeModelView'),
            ('can_show', 'SomeModelView'),
            ('can_add', 'SomeModelView'),
            (permissions.ACTION_CAN_EDIT, 'SomeModelView'),
            (permissions.ACTION_CAN_DELETE, 'SomeModelView'),
        ]
        self.security_manager.init_role(role_name, role_perms)
        role = self.appbuilder.sm.find_role(role_name)
        self.assertIsNotNone(role)
        self.assertEqual(len(role_perms), len(role.permissions))

    def test_update_and_verify_permission_role(self):
        role_name = 'Test_Role'
        self.security_manager.init_role(role_name, [])
        role = self.security_manager.find_role(role_name)

        perm = self.security_manager.find_permission_view_menu(
            permissions.ACTION_CAN_EDIT, 'RoleModelView')
        self.security_manager.add_permission_role(role, perm)
        role_perms_len = len(role.permissions)

        # Re-initialising with no permissions must not strip existing ones.
        self.security_manager.init_role(role_name, [])
        new_role_perms_len = len(role.permissions)

        self.assertEqual(role_perms_len, new_role_perms_len)

    def test_get_user_roles(self):
        user = mock.MagicMock()
        user.is_anonymous = False
        roles = self.appbuilder.sm.find_role('Admin')
        user.roles = roles
        self.assertEqual(self.security_manager.get_user_roles(user), roles)

    @mock.patch('airflow.www.security.AirflowSecurityManager.get_user_roles')
    def test_get_all_permissions_views(self, mock_get_user_roles):
        role_name = 'MyRole5'
        role_perm = 'can_some_action'
        role_vm = 'SomeBaseView'
        username = '******'

        with self.app.app_context():
            user = fab_utils.create_user(
                self.app,
                username,
                role_name,
                permissions=[
                    (role_perm, role_vm),
                ],
            )
            role = user.roles[0]
            mock_get_user_roles.return_value = [role]

            self.assertEqual(self.security_manager.get_all_permissions_views(),
                             {(role_perm, role_vm)})

            mock_get_user_roles.return_value = []
            self.assertEqual(
                len(self.security_manager.get_all_permissions_views()), 0)

    def test_get_accessible_dag_ids(self):
        role_name = 'MyRole1'
        permission_action = [permissions.ACTION_CAN_READ]
        dag_id = 'dag_id'
        username = "******"

        # The same (action, resource) pair was previously listed twice; a
        # duplicate grant has no additional effect, so it is listed once.
        user = fab_utils.create_user(
            self.app,
            username,
            role_name,
            permissions=[
                (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
            ],
        )

        dag_model = DagModel(dag_id=dag_id,
                             fileloc="/tmp/dag_.py",
                             schedule_interval="2 2 * * *")
        self.session.add(dag_model)
        self.session.commit()

        self.security_manager.sync_perm_for_dag(  # type: ignore  # pylint: disable=no-member
            dag_id,
            access_control={role_name: permission_action})

        self.assertEqual(self.security_manager.get_accessible_dag_ids(user),
                         {'dag_id'})

    @mock.patch('airflow.www.security.AirflowSecurityManager._has_view_access')
    def test_has_access(self, mock_has_view_access):
        user = mock.MagicMock()
        user.is_anonymous = False
        mock_has_view_access.return_value = True
        self.assertTrue(self.security_manager.has_access('perm', 'view', user))

    def test_sync_perm_for_dag_creates_permissions_on_view_menus(self):
        test_dag_id = 'TEST_DAG'
        prefixed_test_dag_id = f'DAG:{test_dag_id}'
        self.security_manager.sync_perm_for_dag(test_dag_id,
                                                access_control=None)
        self.assertIsNotNone(
            self.security_manager.find_permission_view_menu(
                permissions.ACTION_CAN_READ, prefixed_test_dag_id))
        self.assertIsNotNone(
            self.security_manager.find_permission_view_menu(
                permissions.ACTION_CAN_EDIT, prefixed_test_dag_id))

    @mock.patch('airflow.www.security.AirflowSecurityManager._has_perm')
    @mock.patch('airflow.www.security.AirflowSecurityManager._has_role')
    def test_has_all_dag_access(self, mock_has_role, mock_has_perm):
        mock_has_role.return_value = True
        self.assertTrue(self.security_manager.has_all_dags_access())

        mock_has_role.return_value = False
        mock_has_perm.return_value = False
        self.assertFalse(self.security_manager.has_all_dags_access())

        mock_has_perm.return_value = True
        self.assertTrue(self.security_manager.has_all_dags_access())

    def test_access_control_with_non_existent_role(self):
        with self.assertRaises(AirflowException) as context:
            self.security_manager.sync_perm_for_dag(
                dag_id='access-control-test',
                access_control={
                    'this-role-does-not-exist':
                    [permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ]
                },
            )
        self.assertIn("role does not exist", str(context.exception))

    def test_all_dag_access_doesnt_give_non_dag_access(self):
        username = '******'
        role_name = 'dag_access_role'
        with self.app.app_context():
            # Single grant; the pair was previously duplicated with no effect.
            user = fab_utils.create_user(
                self.app,
                username,
                role_name,
                permissions=[
                    (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
                ],
            )
            self.assertTrue(
                self.security_manager.has_access(permissions.ACTION_CAN_READ,
                                                 permissions.RESOURCE_DAG,
                                                 user))
            self.assertFalse(
                self.security_manager.has_access(
                    permissions.ACTION_CAN_READ,
                    permissions.RESOURCE_TASK_INSTANCE, user))

    def test_access_control_with_invalid_permission(self):
        invalid_permissions = [
            'can_varimport',  # a real permission, but not a member of DAG_PERMS
            'can_eat_pudding',  # clearly not a real permission
        ]
        username = "******"
        user = fab_utils.create_user(
            self.app,
            username=username,
            role_name='team-a',
        )
        for permission in invalid_permissions:
            self.expect_user_is_in_role(user, rolename='team-a')
            with self.assertRaises(AirflowException) as context:
                self.security_manager.sync_perm_for_dag(
                    'access_control_test',
                    access_control={'team-a': {permission}})
            self.assertIn("invalid permissions", str(context.exception))

    def test_access_control_is_set_on_init(self):
        username = '******'
        role_name = 'team-a'
        with self.app.app_context():
            user = fab_utils.create_user(
                self.app,
                username,
                role_name,
                permissions=[],
            )
            self.expect_user_is_in_role(user, rolename='team-a')
            self.security_manager.sync_perm_for_dag(
                'access_control_test',
                access_control={
                    'team-a':
                    [permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ]
                },
            )
            self.assert_user_has_dag_perms(
                perms=[
                    permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ
                ],
                dag_id='access_control_test',
                user=user,
            )

            self.expect_user_is_in_role(user, rolename='NOT-team-a')
            self.assert_user_does_not_have_dag_perms(
                perms=[
                    permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ
                ],
                dag_id='access_control_test',
                user=user,
            )

    def test_access_control_stale_perms_are_revoked(self):
        username = '******'
        role_name = 'team-a'
        with self.app.app_context():
            user = fab_utils.create_user(
                self.app,
                username,
                role_name,
                permissions=[],
            )
            self.expect_user_is_in_role(user, rolename='team-a')
            self.security_manager.sync_perm_for_dag(
                'access_control_test', access_control={'team-a': READ_WRITE})
            self.assert_user_has_dag_perms(perms=READ_WRITE,
                                           dag_id='access_control_test',
                                           user=user)

            self.security_manager.sync_perm_for_dag(
                'access_control_test', access_control={'team-a': READ_ONLY})
            self.assert_user_has_dag_perms(perms=[permissions.ACTION_CAN_READ],
                                           dag_id='access_control_test',
                                           user=user)
            self.assert_user_does_not_have_dag_perms(
                perms=[permissions.ACTION_CAN_EDIT],
                dag_id='access_control_test',
                user=user)

    def test_no_additional_dag_permission_views_created(self):
        ab_perm_view_role = sqla_models.assoc_permissionview_role

        self.security_manager.sync_roles()
        num_pv_before = self.db.session().query(ab_perm_view_role).count()
        self.security_manager.sync_roles()
        num_pv_after = self.db.session().query(ab_perm_view_role).count()
        self.assertEqual(num_pv_before, num_pv_after)

    def test_override_role_vm(self):
        test_security_manager = MockSecurityManager(appbuilder=self.appbuilder)
        self.assertEqual(len(test_security_manager.VIEWER_VMS), 1)
        self.assertEqual(test_security_manager.VIEWER_VMS, {'Airflow'})
Beispiel #31
0
    def parse_manifest_json(self) -> None:
        """Load the ``entrypoints`` section of ``self.manifest_file`` into ``self.manifest``.

        Best effort: if the file cannot be opened or parsed, ``self.manifest``
        is left unchanged. The failure is logged at DEBUG level rather than
        silently discarded.
        """
        import logging

        try:
            # JSON text is UTF-8; don't depend on the platform's locale encoding.
            with open(self.manifest_file, "r", encoding="utf-8") as f:
                # the manifest includes non-entry files we only need entries in
                # templates
                full_manifest = json.load(f)
                self.manifest = full_manifest.get("entrypoints", {})
        except Exception:  # pylint: disable=broad-except
            # Kept broad on purpose (best effort), but no longer silent.
            logging.getLogger(__name__).debug(
                "Could not parse manifest file %s",
                getattr(self, "manifest_file", None),
                exc_info=True,
            )

    def get_manifest_files(self, bundle: str, asset_type: str) -> List[str]:
        """Return the asset paths of type ``asset_type`` for entry ``bundle``.

        When an app is attached and running in debug mode, the manifest is
        re-read on every call first. Unknown bundles or asset types give ``[]``.
        """
        app = self.app
        if app and app.debug:
            self.parse_manifest_json()
        bundle_assets = self.manifest.get(bundle, {})
        return bundle_assets.get(asset_type, [])


# Directory containing this module; handed to UIManifestProcessor below.
APP_DIR = os.path.dirname(__file__)
# Module-level extension singletons, created without an app at import time.
# NOTE(review): presumably bound to the Flask app later by an initialisation
# step that is not visible here — confirm.
appbuilder = AppBuilder(update_perms=False)
cache_manager = CacheManager()
celery_app = celery.Celery()
db = SQLA()
# Holder for the event logger; the proxy below resolves it on each access and
# yields None while no "event_logger" entry has been stored.
_event_logger: Dict[str, Any] = {}
event_logger = LocalProxy(lambda: _event_logger.get("event_logger"))
feature_flag_manager = FeatureFlagManager()
jinja_context_manager = JinjaContextManager()
manifest_processor = UIManifestProcessor(APP_DIR)
migrate = Migrate()
results_backend_manager = ResultsBackendManager()
# Resolves to ``appbuilder.sm`` at access time rather than at import time.
security_manager = LocalProxy(lambda: appbuilder.sm)
talisman = Talisman()
class TestSecurity(unittest.TestCase):
    """Tests for ``AirflowSecurityManager`` running inside a throwaway app.

    ``setUp`` builds a fresh Flask app, ``SQLA`` database and ``AppBuilder``
    using ``AirflowSecurityManager``. It registers two sample views and
    creates an ``admin`` user in the ``Admin`` role (``self.user``). The
    DAG-permission helpers below always check against that user.
    """

    def setUp(self):
        self.app = Flask(__name__)
        # Empty SQLite database path: SQLAlchemy treats this as in-memory.
        self.app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///'
        self.app.config['SECRET_KEY'] = 'secret_key'
        # CSRF protection is switched off for the test app.
        self.app.config['CSRF_ENABLED'] = False
        self.app.config['WTF_CSRF_ENABLED'] = False
        self.db = SQLA(self.app)
        self.appbuilder = AppBuilder(self.app,
                                     self.db.session,
                                     security_manager_class=AirflowSecurityManager)
        self.security_manager = self.appbuilder.sm
        # Sample views whose permissions the init_role tests grant.
        self.appbuilder.add_view(SomeBaseView, "SomeBaseView", category="BaseViews")
        self.appbuilder.add_view(SomeModelView, "SomeModelView", category="ModelViews")
        role_admin = self.security_manager.find_role('Admin')
        self.user = self.appbuilder.sm.add_user('admin', 'admin', 'user', '*****@*****.**',
                                                role_admin, 'general')
        log.debug("Complete setup!")

    def expect_user_is_in_role(self, user, rolename):
        """Make ``rolename`` the only role of ``user`` and persist the change.

        The role is initialised via ``init_role`` with no view menus or
        permissions. If it still cannot be found, it is added with ``add_role``.
        """
        self.security_manager.init_role(rolename, [], [])
        role = self.security_manager.find_role(rolename)
        if not role:
            # Fallback for the case where init_role did not create the role.
            self.security_manager.add_role(rolename)
            role = self.security_manager.find_role(rolename)
        user.roles = [role]
        self.security_manager.update_user(user)

    def assert_user_has_dag_perms(self, perms, dag_id):
        """Assert that ``self.user`` holds every permission in ``perms`` on ``dag_id``."""
        for perm in perms:
            self.assertTrue(
                self._has_dag_perm(perm, dag_id),
                "User should have '{}' on DAG '{}'".format(perm, dag_id))

    def assert_user_does_not_have_dag_perms(self, dag_id, perms):
        """Assert that ``self.user`` holds none of ``perms`` on ``dag_id``.

        Note that the parameter order is (dag_id, perms), the reverse of
        ``assert_user_has_dag_perms``. Call it with keyword arguments.
        """
        for perm in perms:
            self.assertFalse(
                self._has_dag_perm(perm, dag_id),
                "User should not have '{}' on DAG '{}'".format(perm, dag_id))

    def _has_dag_perm(self, perm, dag_id):
        """Return ``security_manager.has_access(perm, dag_id, self.user)``."""
        return self.security_manager.has_access(
            perm,
            dag_id,
            self.user)

    def tearDown(self):
        # Clear every attribute set in setUp. The security manager (taken from
        # the app builder) and the DB-backed user would otherwise keep the
        # app/database objects reachable after the test finishes.
        self.appbuilder = None
        self.app = None
        self.db = None
        self.security_manager = None
        self.user = None
        log.debug("Complete teardown!")

    def test_init_role_baseview(self):
        """init_role grants the requested permissions on a BaseView."""
        role_name = 'MyRole1'
        role_perms = ['can_some_action']
        role_vms = ['SomeBaseView']
        self.security_manager.init_role(role_name, role_vms, role_perms)
        role = self.appbuilder.sm.find_role(role_name)
        self.assertIsNotNone(role)
        self.assertEqual(len(role_perms), len(role.permissions))

    def test_init_role_modelview(self):
        """init_role grants the requested CRUD permissions on a ModelView."""
        role_name = 'MyRole2'
        role_perms = ['can_list', 'can_show', 'can_add', 'can_edit', 'can_delete']
        role_vms = ['SomeModelView']
        self.security_manager.init_role(role_name, role_vms, role_perms)
        role = self.appbuilder.sm.find_role(role_name)
        self.assertIsNotNone(role)
        self.assertEqual(len(role_perms), len(role.permissions))

    def test_update_and_verify_permission_role(self):
        """Re-running init_role leaves an existing role's permission count unchanged."""
        role_name = 'Test_Role'
        self.security_manager.init_role(role_name, [], [])
        role = self.security_manager.find_role(role_name)

        perm = self.security_manager.\
            find_permission_view_menu('can_edit', 'RoleModelView')
        self.security_manager.add_permission_role(role, perm)
        role_perms_len = len(role.permissions)

        self.security_manager.init_role(role_name, [], [])
        new_role_perms_len = len(role.permissions)

        self.assertEqual(role_perms_len, new_role_perms_len)

    def test_get_user_roles(self):
        """get_user_roles returns the roles of a non-anonymous user."""
        user = mock.MagicMock()
        user.is_anonymous = False
        roles = self.appbuilder.sm.find_role('Admin')
        user.roles = roles
        self.assertEqual(self.security_manager.get_user_roles(user), roles)

    @mock.patch('airflow.www.security.AirflowSecurityManager.get_user_roles')
    def test_get_all_permissions_views(self, mock_get_user_roles):
        """Permission views are collected from the user's roles (none for no roles)."""
        role_name = 'MyRole1'
        role_perms = ['can_some_action']
        role_vms = ['SomeBaseView']
        self.security_manager.init_role(role_name, role_vms, role_perms)
        role = self.security_manager.find_role(role_name)

        mock_get_user_roles.return_value = [role]
        self.assertEqual(self.security_manager
                         .get_all_permissions_views(),
                         {('can_some_action', 'SomeBaseView')})

        mock_get_user_roles.return_value = []
        self.assertEqual(len(self.security_manager
                             .get_all_permissions_views()), 0)

    # mock.patch decorators apply bottom-up: get_user_roles is the first mock arg.
    @mock.patch('airflow.www.security.AirflowSecurityManager'
                '.get_all_permissions_views')
    @mock.patch('airflow.www.security.AirflowSecurityManager'
                '.get_user_roles')
    def test_get_accessible_dag_ids(self, mock_get_user_roles,
                                    mock_get_all_permissions_views):
        """A can_dag_read permission on a DAG makes that DAG accessible."""
        user = mock.MagicMock()
        role_name = 'MyRole1'
        role_perms = ['can_dag_read']
        role_vms = ['dag_id']
        self.security_manager.init_role(role_name, role_vms, role_perms)
        role = self.security_manager.find_role(role_name)
        user.roles = [role]
        user.is_anonymous = False
        mock_get_all_permissions_views.return_value = {('can_dag_read', 'dag_id')}

        mock_get_user_roles.return_value = [role]
        self.assertEqual(self.security_manager
                         .get_accessible_dag_ids(user), {'dag_id'})

    @mock.patch('airflow.www.security.AirflowSecurityManager._has_view_access')
    def test_has_access(self, mock_has_view_access):
        """has_access is True when view access is granted."""
        user = mock.MagicMock()
        user.is_anonymous = False
        mock_has_view_access.return_value = True
        self.assertTrue(self.security_manager.has_access('perm', 'view', user))

    def test_sync_perm_for_dag_creates_permissions_on_view_menus(self):
        """sync_perm_for_dag creates a permission view for every DAG permission."""
        test_dag_id = 'TEST_DAG'
        self.security_manager.sync_perm_for_dag(test_dag_id, access_control=None)
        for dag_perm in self.security_manager.DAG_PERMS:
            self.assertIsNotNone(self.security_manager.
                                 find_permission_view_menu(dag_perm, test_dag_id))

    # Bottom-up decorator order: _has_role is the first mock argument.
    @mock.patch('airflow.www.security.AirflowSecurityManager._has_perm')
    @mock.patch('airflow.www.security.AirflowSecurityManager._has_role')
    def test_has_all_dag_access(self, mock_has_role, mock_has_perm):
        """All-DAG access follows from either the role check or the permission check."""
        mock_has_role.return_value = True
        self.assertTrue(self.security_manager.has_all_dags_access())

        mock_has_role.return_value = False
        mock_has_perm.return_value = False
        self.assertFalse(self.security_manager.has_all_dags_access())

        mock_has_perm.return_value = True
        self.assertTrue(self.security_manager.has_all_dags_access())

    def test_access_control_with_non_existent_role(self):
        """Referencing an unknown role in access_control raises AirflowException."""
        with self.assertRaises(AirflowException) as context:
            self.security_manager.sync_perm_for_dag(
                dag_id='access-control-test',
                access_control={
                    'this-role-does-not-exist': ['can_dag_edit', 'can_dag_read']
                })
        self.assertIn("role does not exist", str(context.exception))

    def test_access_control_with_invalid_permission(self):
        """Permissions outside DAG_PERMS are rejected with AirflowException."""
        invalid_permissions = [
            'can_varimport',  # a real permission, but not a member of DAG_PERMS
            'can_eat_pudding',  # clearly not a real permission
        ]
        for permission in invalid_permissions:
            self.expect_user_is_in_role(self.user, rolename='team-a')
            with self.assertRaises(AirflowException) as context:
                self.security_manager.sync_perm_for_dag(
                    'access_control_test',
                    access_control={
                        'team-a': {permission}
                    })
            self.assertIn("invalid permissions", str(context.exception))

    def test_access_control_is_set_on_init(self):
        """Only members of the role named in access_control get the DAG permissions."""
        self.expect_user_is_in_role(self.user, rolename='team-a')
        self.security_manager.sync_perm_for_dag(
            'access_control_test',
            access_control={
                'team-a': ['can_dag_edit', 'can_dag_read']
            })
        self.assert_user_has_dag_perms(
            perms=['can_dag_edit', 'can_dag_read'],
            dag_id='access_control_test',
        )

        self.expect_user_is_in_role(self.user, rolename='NOT-team-a')
        self.assert_user_does_not_have_dag_perms(
            perms=['can_dag_edit', 'can_dag_read'],
            dag_id='access_control_test',
        )

    def test_access_control_stale_perms_are_revoked(self):
        """Narrowing a role's access_control revokes the permissions that were dropped."""
        READ_WRITE = {'can_dag_read', 'can_dag_edit'}
        READ_ONLY = {'can_dag_read'}

        self.expect_user_is_in_role(self.user, rolename='team-a')
        self.security_manager.sync_perm_for_dag(
            'access_control_test',
            access_control={'team-a': READ_WRITE})
        self.assert_user_has_dag_perms(
            perms=READ_WRITE,
            dag_id='access_control_test',
        )

        self.security_manager.sync_perm_for_dag(
            'access_control_test',
            access_control={'team-a': READ_ONLY})
        self.assert_user_has_dag_perms(
            perms=['can_dag_read'],
            dag_id='access_control_test',
        )
        self.assert_user_does_not_have_dag_perms(
            perms=['can_dag_edit'],
            dag_id='access_control_test',
        )

    def test_no_additional_dag_permission_views_created(self):
        """A second sync_roles call does not add permission-view/role rows."""
        ab_perm_view_role = sqla_models.assoc_permissionview_role

        self.security_manager.sync_roles()
        num_pv_before = self.db.session().query(ab_perm_view_role).count()
        self.security_manager.sync_roles()
        num_pv_after = self.db.session().query(ab_perm_view_role).count()
        self.assertEqual(num_pv_before, num_pv_after)

    def test_override_role_vm(self):
        """TestSecurityManager exposes exactly {'Airflow'} as VIEWER_VMS."""
        test_security_manager = TestSecurityManager(appbuilder=self.appbuilder)
        self.assertEqual(len(test_security_manager.VIEWER_VMS), 1)
        self.assertEqual(test_security_manager.VIEWER_VMS, {'Airflow'})
import logging
from flask import Flask
from flask_appbuilder import AppBuilder, SQLA
from sqlalchemy.engine import Engine
from sqlalchemy import event
from .indexview import FABView

# Root logging: timestamped records, root logger forced to DEBUG.
_LOG_FORMAT = '%(asctime)s:%(levelname)s:%(name)s:%(message)s'
logging.basicConfig(format=_LOG_FORMAT)
logging.root.setLevel(logging.DEBUG)

# Flask application configured from the ``config`` module, plus the
# Flask-AppBuilder database and app builder bound to it.
app = Flask(__name__)
app.config.from_object('config')
db = SQLA(app)
appbuilder = AppBuilder(
    app,
    db.session,
    base_template='mybase.html',
    indexview=FABView,
)


"""
Only include this for SQLLite constraints
"""
@event.listens_for(Engine, "connect")
def set_sqlite_pragma(dbapi_connection, connection_record):
    cursor = dbapi_connection.cursor()
    cursor.execute("PRAGMA foreign_keys=ON")
    cursor.close()
    

# These imports deliberately come after ``app``, ``db`` and ``appbuilder``
# are created. The imported modules presumably import those names back from
# this package (circular import) -- confirm in app/views.py and app/api.py.
from app import views, data
from app import api

# Create the tables for the models registered on ``db`` (tables that already
# exist are presumably left as-is).
db.create_all()
# Presumably seeds the gender lookup data -- see app/data.py.
data.fill_gender()
Beispiel #34
0
def upgrade_db(config, backup):
    """
        Upgrade your database after F.A.B. upgrade if necessary (SQLAlchemy only)

        Version 1.3.0 upgrade needs database upgrade, read version migration on docs for
        further details.
    """
    from flask import Flask
    from flask_appbuilder import SQLA
    from flask_appbuilder.security.sqla.models import User
    from sqlalchemy import Column, Integer, ForeignKey
    from sqlalchemy.orm import relationship


    class UpgProxyUser(User):
        role_id = Column(Integer, ForeignKey('ab_role.id'))
        role = relationship('Role')

    sequenceremap={ 'seq_ab_permission_pk':'ab_permission_id_seq',
                    'seq_ab_view_menu_pk' :'ab_view_menu_id_seq',
                    'seq_permission_view_pk': 'ab_permission_view_id_seq',
                    'seq_ab_permission_view_role_pk': 'ab_permission_view_role_id_seq',
                    'seq_ab_role_pk rename': 'ab_role_id_seq',
                    'seq_ab_user_role_pk': 'ab_user_role_id_seq',
                    'seq_ab_user_pk': 'ab_user_id_seq',
                    'seq_ab_register_user_pk': 'ab_register_user_id_seq'
                }

    del_column_stmt = {'mysql': 'ALTER TABLE %s DROP COLUMN %s',
                    'postgresql': 'ALTER TABLE %s DROP COLUMN %s',
                    'oracle': 'ALTER TABLE %s DROP COLUMN %s',
                    'mssql': 'ALTER TABLE %s DROP COLUMN %s'}

    del_foreign_stmt = {'mysql': 'ALTER TABLE %s DROP FOREIGN KEY %s',
                    'postgresql': 'ALTER TABLE %s DROP CONSTRAINT %s',
                    'oracle': 'ALTER TABLE %s DROP CONSTRAINT %s',
                    'mssql': 'ALTER TABLE %s DROP CONSTRAINT %s'}

    def del_column(conn, table_name, column_name):
        try:
            if conn.engine.name in del_column_stmt:
                click.echo(click.style("Going to delete Column {0} on {1}".format(column_name, table_name), fg='green'))
                conn.engine.execute(del_column_stmt[conn.engine.name] % (table_name, column_name))
                click.echo(click.style("Deleted Column {0} on {1}".format(column_name, table_name), fg='green'))
            else:
                click.echo(click.style("Engine {0} not supported for auto upgrade, del column {1}.{2} yourself" \
                                       .format(conn.engine.name, table_name, column_name), fg='red'))
        except Exception as e:
            click.echo(click.style("Error deleting Column {0} on {1}: {2}".format(column_name, table_name, str(e)), fg='red'))

    def del_foreign_key(conn, table_name, column_name):
        try:
            if conn.engine.name in del_foreign_stmt:
                click.echo(click.style("Going to drop FK {0} on {1}".format(column_name, table_name), fg='green'))
                conn.engine.execute(del_foreign_stmt[conn.engine.name] % (table_name, column_name))
                click.echo(click.style("Droped FK {0} on {1}".format(column_name, table_name), fg='green'))
            else:
                click.echo(click.style("Engine {0} not supported for auto upgrade, del FK {1}.{2} yourself" \
                                       .format(conn.engine.name, table_name, column_name), fg='red'))
        except Exception as e:
            click.echo(click.style("Error droping FK {0} on {1}: {2}".format(column_name, table_name, str(e)), fg='red'))


    if not backup.lower() in ('yes', 'y'):
        click.echo(click.style('Please backup first', fg='red'))
        exit(0)
    sys.path.append(os.getcwd())
    
    app = Flask(__name__)
    app.config.from_object(config)
    db = SQLA(app)
    db.create_all()

    # Upgrade Users append role on roles, allows 1.3.0 multiple roles for user.
    click.echo(click.style('Beginning user migration, hope you have backed up first', fg='green'))
    try:
        for user in db.session.query(UpgProxyUser).all():
            user.roles.append(user.role)
            db.session.commit()
            click.echo(click.style('Altered user {0}'.format(user.username), fg='green'))
    except:
            click.echo(click.style('Error on Upgrade, DB is probably already compliant', fg='green'))
            exit(0)
    db.session.remove()

    if db.engine.name == 'sqlite':
        click.echo(click.style('\n------------------\nTo finish the upgrade you must download and execute the following sql\n\
        Download from https://github.com/dpgaspar/Flask-AppBuilder/tree/master/bin/sqlite_upgrade_1.3.sql', fg='green'))
        exit(0)

    del_foreign_key(db, 'ab_user', 'ab_user_role_id_fkey')
    del_column(db, 'ab_user', 'role_id')
    
    # POSTGRESQL
    if db.engine.name == 'postgresql':
        for seq in sequenceremap.keys():
            try:
                checksequence=db.engine.execute("SELECT 0 from pg_class where relname=%s;", seq)
                if checksequence.fetchone() is not None:
                    db.engine.execute("alter sequence %s rename to %s;" % (seq, sequenceremap[seq]))
                click.echo(click.style('Altered sequence from {0} {1}'.format(seq,sequenceremap[seq]), fg='green'))
            except:
                click.echo(click.style('Error Altering sequence from {0} to {1}'.format(seq,sequenceremap[seq]), fg='red'))
    # ORACLE
    if db.engine.name == 'oracle':
        click.echo(click.style('Your using PostgreSQL going to change your sequence names', fg='green'))
        for seq in sequenceremap.keys():
            try:
                db.engine.execute("rename %s to %s;" % (seq, sequenceremap[seq]))
                click.echo(click.style('Altered sequence from {0} to {1}'.format(seq, sequenceremap[seq]), fg='green'))
            except:
                click.echo(click.style('Error Altering sequence from {0} to {1}'.format(seq, sequenceremap[seq]), fg='red'))