Example #1
    def test_clean_requests_after_db_grant(self):
        session = db.session

        # Case 3. Two access requests from gamma and gamma2
        # Gamma gets database access, gamma2 access request granted
        # Check if request by gamma has been deleted

        gamma_user = security_manager.find_user(username='******')
        access_request1 = create_access_request(
            session, 'table', 'energy_usage', TEST_ROLE_1, 'gamma')
        create_access_request(
            session, 'table', 'energy_usage', TEST_ROLE_2, 'gamma2')
        ds_1_id = access_request1.datasource_id
        # gamma gets granted database access
        database = session.query(models.Database).first()

        security_manager.merge_perm('database_access', database.perm)
        ds_perm_view = security_manager.find_permission_view_menu(
            'database_access', database.perm)
        security_manager.add_permission_role(
            security_manager.find_role(DB_ACCESS_ROLE), ds_perm_view)
        gamma_user.roles.append(security_manager.find_role(DB_ACCESS_ROLE))
        session.commit()
        access_requests = self.get_access_requests('gamma', 'table', ds_1_id)
        self.assertTrue(access_requests)
        # gamma2 request gets fulfilled
        self.client.get(EXTEND_ROLE_REQUEST.format(
            'table', ds_1_id, 'gamma2', TEST_ROLE_2))
        access_requests = self.get_access_requests('gamma', 'table', ds_1_id)

        self.assertFalse(access_requests)
        gamma_user = security_manager.find_user(username='******')
        gamma_user.roles.remove(security_manager.find_role(DB_ACCESS_ROLE))
        session.commit()
Example #2
    def test_clean_requests_after_alpha_grant(self):
        session = db.session

        # Case 2. Two access requests from gamma and gamma2
        # Gamma becomes alpha, gamma2 gets granted
        # Check if request by gamma has been deleted

        access_request1 = create_access_request(
            session, 'table', 'birth_names', TEST_ROLE_1, 'gamma')
        create_access_request(
            session, 'table', 'birth_names', TEST_ROLE_2, 'gamma2')
        ds_1_id = access_request1.datasource_id
        # gamma becomes alpha
        alpha_role = security_manager.find_role('Alpha')
        gamma_user = security_manager.find_user(username='******')
        gamma_user.roles.append(alpha_role)
        session.commit()
        access_requests = self.get_access_requests('gamma', 'table', ds_1_id)
        self.assertTrue(access_requests)
        self.client.get(EXTEND_ROLE_REQUEST.format(
            'table', ds_1_id, 'gamma2', TEST_ROLE_2))
        access_requests = self.get_access_requests('gamma', 'table', ds_1_id)
        self.assertFalse(access_requests)

        gamma_user = security_manager.find_user(username='******')
        gamma_user.roles.remove(security_manager.find_role('Alpha'))
        session.commit()
Example #3
def load_test_users_run():
    """
    Loads the admin, alpha, and gamma users for testing purposes

    Syncs permissions for those users/roles
    """
    if config.get('TESTING'):
        security_manager.sync_role_definitions()
        gamma_sqllab_role = security_manager.add_role('gamma_sqllab')
        for perm in security_manager.find_role('Gamma').permissions:
            security_manager.add_permission_role(gamma_sqllab_role, perm)
        utils.get_or_create_main_db()
        db_perm = utils.get_main_database(security_manager.get_session).perm
        security_manager.merge_perm('database_access', db_perm)
        db_pvm = security_manager.find_permission_view_menu(
            view_menu_name=db_perm, permission_name='database_access')
        gamma_sqllab_role.permissions.append(db_pvm)
        for perm in security_manager.find_role('sql_lab').permissions:
            security_manager.add_permission_role(gamma_sqllab_role, perm)

        admin = security_manager.find_user('admin')
        if not admin:
            security_manager.add_user(
                'admin', 'admin', ' user', '*****@*****.**',
                security_manager.find_role('Admin'),
                password='******')

        gamma = security_manager.find_user('gamma')
        if not gamma:
            security_manager.add_user(
                'gamma', 'gamma', 'user', '*****@*****.**',
                security_manager.find_role('Gamma'),
                password='******')

        gamma2 = security_manager.find_user('gamma2')
        if not gamma2:
            security_manager.add_user(
                'gamma2', 'gamma2', 'user', '*****@*****.**',
                security_manager.find_role('Gamma'),
                password='******')

        gamma_sqllab_user = security_manager.find_user('gamma_sqllab')
        if not gamma_sqllab_user:
            security_manager.add_user(
                'gamma_sqllab', 'gamma_sqllab', 'user', '*****@*****.**',
                gamma_sqllab_role, password='******')

        alpha = security_manager.find_user('alpha')
        if not alpha:
            security_manager.add_user(
                'alpha', 'alpha', 'user', '*****@*****.**',
                security_manager.find_role('Alpha'),
                password='******')
        security_manager.get_session.commit()
Example #4
    def test_sql_json_has_access(self):
        main_db = self.get_main_database(db.session)
        security_manager.add_permission_view_menu('database_access', main_db.perm)
        db.session.commit()
        main_db_permission_view = (
            db.session.query(ab_models.PermissionView)
            .join(ab_models.ViewMenu)
            .join(ab_models.Permission)
            .filter(ab_models.ViewMenu.name == '[main].(id:1)')
            .filter(ab_models.Permission.name == 'database_access')
            .first()
        )
        astronaut = security_manager.add_role('Astronaut')
        security_manager.add_permission_role(astronaut, main_db_permission_view)
        # Astronaut role is Gamma + sqllab + main db permissions
        for perm in security_manager.find_role('Gamma').permissions:
            security_manager.add_permission_role(astronaut, perm)
        for perm in security_manager.find_role('sql_lab').permissions:
            security_manager.add_permission_role(astronaut, perm)

        gagarin = security_manager.find_user('gagarin')
        if not gagarin:
            security_manager.add_user(
                'gagarin', 'Iurii', 'Gagarin', '*****@*****.**',
                astronaut,
                password='******')
        data = self.run_sql('SELECT * FROM ab_user', '3', user_name='gagarin')
        db.session.query(Query).delete()
        db.session.commit()
        self.assertLess(0, len(data['data']))
Example #5
    def test_owners_can_view_empty_dashboard(self):
        dash = (
            db.session
            .query(models.Dashboard)
            .filter_by(slug='empty_dashboard')
            .first()
        )
        if not dash:
            dash = models.Dashboard()
            dash.dashboard_title = 'Empty Dashboard'
            dash.slug = 'empty_dashboard'
        else:
            dash.slices = []
            dash.owners = []
        db.session.merge(dash)
        db.session.commit()

        gamma_user = security_manager.find_user('gamma')
        self.login(gamma_user.username)

        resp = self.get_resp('/dashboardmodelview/list/')
        self.assertNotIn('/superset/dashboard/empty_dashboard/', resp)

        dash = (
            db.session
            .query(models.Dashboard)
            .filter_by(slug='empty_dashboard')
            .first()
        )
        dash.owners = [gamma_user]
        db.session.merge(dash)
        db.session.commit()

        resp = self.get_resp('/dashboardmodelview/list/')
        self.assertIn('/superset/dashboard/empty_dashboard/', resp)
Example #6
    def test_only_owners_can_save(self):
        dash = (
            db.session
            .query(models.Dashboard)
            .filter_by(slug='births')
            .first()
        )
        dash.owners = []
        db.session.merge(dash)
        db.session.commit()
        self.test_save_dash('admin')

        self.logout()
        self.assertRaises(
            Exception, self.test_save_dash, 'alpha')

        alpha = security_manager.find_user('alpha')

        dash = (
            db.session
            .query(models.Dashboard)
            .filter_by(slug='births')
            .first()
        )
        dash.owners = [alpha]
        db.session.merge(dash)
        db.session.commit()
        self.test_save_dash('alpha')
Example #7
    def test_clean_requests_after_role_extend(self):
        session = db.session

        # Case 1. Gamma and gamma2 requested test_role1 on random_time_series
        # Gamma already has role test_role1
        # Extend test_role1 with random_time_series access for gamma2
        # Check if the access request for gamma on random_time_series was deleted

        # gamma2 and gamma request test_role1 on random_time_series
        if app.config.get('ENABLE_ACCESS_REQUEST'):
            access_request1 = create_access_request(
                session, 'table', 'random_time_series', TEST_ROLE_1, 'gamma2')
            ds_1_id = access_request1.datasource_id
            create_access_request(
                session, 'table', 'random_time_series', TEST_ROLE_1, 'gamma')
            access_requests = self.get_access_requests('gamma', 'table', ds_1_id)
            self.assertTrue(access_requests)
            # gamma gets test_role1
            self.get_resp(GRANT_ROLE_REQUEST.format(
                'table', ds_1_id, 'gamma', TEST_ROLE_1))
            # extend test_role1 with access on random_time_series
            self.client.get(EXTEND_ROLE_REQUEST.format(
                'table', ds_1_id, 'gamma2', TEST_ROLE_1))
            access_requests = self.get_access_requests('gamma', 'table', ds_1_id)
            self.assertFalse(access_requests)

            gamma_user = security_manager.find_user(username='******')
            gamma_user.roles.remove(security_manager.find_role('test_role1'))
Example #8
    def test_user_profile(self, username='******'):
        self.login(username=username)
        slc = self.get_slice('Girls', db.session)

        # Setting some faves
        url = '/superset/favstar/Slice/{}/select/'.format(slc.id)
        resp = self.get_json_resp(url)
        self.assertEqual(resp['count'], 1)

        dash = (
            db.session
            .query(models.Dashboard)
            .filter_by(slug='births')
            .first()
        )
        url = '/superset/favstar/Dashboard/{}/select/'.format(dash.id)
        resp = self.get_json_resp(url)
        self.assertEqual(resp['count'], 1)

        userid = security_manager.find_user('admin').id
        resp = self.get_resp('/superset/profile/admin/')
        self.assertIn('"app"', resp)
        data = self.get_json_resp('/superset/recent_activity/{}/'.format(userid))
        self.assertNotIn('message', data)
        data = self.get_json_resp('/superset/created_slices/{}/'.format(userid))
        self.assertNotIn('message', data)
        data = self.get_json_resp('/superset/created_dashboards/{}/'.format(userid))
        self.assertNotIn('message', data)
        data = self.get_json_resp('/superset/fave_slices/{}/'.format(userid))
        self.assertNotIn('message', data)
        data = self.get_json_resp('/superset/fave_dashboards/{}/'.format(userid))
        self.assertNotIn('message', data)
        data = self.get_json_resp(
            '/superset/fave_dashboards_by_username/{}/'.format(username))
        self.assertNotIn('message', data)
Example #9
    def test_search_query_on_user(self):
        self.run_some_queries()
        self.login('admin')

        # Test search queries on user Id
        user_id = security_manager.find_user('admin').id
        data = self.get_json_resp(
            '/superset/search_queries?user_id={}'.format(user_id))
        self.assertEquals(2, len(data))
        user_ids = {k['userId'] for k in data}
        self.assertEquals(set([user_id]), user_ids)

        user_id = security_manager.find_user('gamma_sqllab').id
        resp = self.get_resp(
            '/superset/search_queries?user_id={}'.format(user_id))
        data = json.loads(resp)
        self.assertEquals(1, len(data))
        self.assertEquals(data[0]['userId'], user_id)
Example #10
    def test_queryview_filter(self) -> None:
        """
        Test queryview api without can_only_access_owned_queries perm added to
        Admin and make sure all queries show up.
        """
        self.run_some_queries()
        self.login(username='******')

        url = '/queryview/api/read'
        data = self.get_json_resp(url)
        admin = security_manager.find_user('admin')
        gamma_sqllab = security_manager.find_user('gamma_sqllab')
        self.assertEquals(3, len(data['result']))
        user_queries = [
            result.get('username') for result in data['result']
        ]
        assert admin.username in user_queries
        assert gamma_sqllab.username in user_queries
Example #11
    def test_dashboard_with_created_by_can_be_accessed_by_public_users(self):
        self.logout()
        table = (
            db.session
            .query(SqlaTable)
            .filter_by(table_name='birth_names')
            .one()
        )
        self.grant_public_access_to_table(table)

        dash = db.session.query(models.Dashboard).filter_by(
            slug='births').first()
        dash.owners = [security_manager.find_user('admin')]
        dash.created_by = security_manager.find_user('admin')
        db.session.merge(dash)
        db.session.commit()

        assert 'Births' in self.get_resp('/superset/dashboard/births/')
Example #12
    def get_access_requests(self, username, ds_type, ds_id):
        DAR = models.DatasourceAccessRequest
        return (
            db.session.query(DAR)
            .filter(
                DAR.created_by == security_manager.find_user(username=username),
                DAR.datasource_type == ds_type,
                DAR.datasource_id == ds_id,
            )
            .first()
        )
Example #13
    def test_clean_requests_after_schema_grant(self):
        session = db.session

        # Case 4. Two access requests from gamma and gamma2
        # Gamma gets schema access, gamma2 access request granted
        # Check if request by gamma has been deleted

        gamma_user = security_manager.find_user(username='******')
        access_request1 = create_access_request(
            session, 'table', 'wb_health_population', TEST_ROLE_1, 'gamma')
        create_access_request(
            session, 'table', 'wb_health_population', TEST_ROLE_2, 'gamma2')
        ds_1_id = access_request1.datasource_id
        ds = session.query(SqlaTable).filter_by(
            table_name='wb_health_population').first()

        ds.schema = 'temp_schema'
        security_manager.merge_perm('schema_access', ds.schema_perm)
        schema_perm_view = security_manager.find_permission_view_menu(
            'schema_access', ds.schema_perm)
        security_manager.add_permission_role(
            security_manager.find_role(SCHEMA_ACCESS_ROLE), schema_perm_view)
        gamma_user.roles.append(security_manager.find_role(SCHEMA_ACCESS_ROLE))
        session.commit()
        # gamma2 request gets fulfilled
        self.client.get(EXTEND_ROLE_REQUEST.format(
            'table', ds_1_id, 'gamma2', TEST_ROLE_2))
        access_requests = self.get_access_requests('gamma', 'table', ds_1_id)
        self.assertFalse(access_requests)
        gamma_user = security_manager.find_user(username='******')
        gamma_user.roles.remove(security_manager.find_role(SCHEMA_ACCESS_ROLE))

        ds = session.query(SqlaTable).filter_by(
            table_name='wb_health_population').first()
        ds.schema = None

        session.commit()
Example #14
def _get_auth_cookies():
    # Log in as the user configured to generate the email reports
    with app.test_request_context():
        user = security_manager.find_user(config.get('EMAIL_REPORTS_USER'))
        login_user(user)

        # A mock response object to get the cookie information from
        response = Response()
        app.session_interface.save_session(app, session, response)

    cookies = []

    # Extract the session cookie value from each Set-Cookie header
    for name, value in response.headers:
        if name.lower() == 'set-cookie':
            cookie = parse_cookie(value)
            cookies.append(cookie['session'])

    return cookies
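
The helper above only collects the session cookie values; callers still have to attach them to whatever client fetches the reports. A minimal, hypothetical usage sketch of that step, assuming a Selenium driver object named driver and a base_url pointing at the Superset instance (neither is defined in the source above):

# Hypothetical usage sketch -- 'driver' and 'base_url' are assumptions, not part of the source.
driver.get(base_url)  # the target domain must be loaded before cookies can be attached
for session_value in _get_auth_cookies():
    driver.add_cookie({"name": "session", "value": session_value})
driver.get(base_url + "/dashboard/list/")  # subsequent page loads are now authenticated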
Example #15
    def test_queryview_filter_owner_only(self) -> None:
        """
        Test queryview api with can_only_access_owned_queries perm added to
        Admin and make sure only Admin queries show up.
        """
        session = db.session

        # Add can_only_access_owned_queries perm to Admin user
        owned_queries_view = security_manager.find_permission_view_menu(
            'can_only_access_owned_queries',
            'can_only_access_owned_queries',
        )
        security_manager.add_permission_role(
            security_manager.find_role('Admin'),
            owned_queries_view,
        )
        session.commit()

        # Test search_queries for Admin user
        self.run_some_queries()
        self.login('admin')

        url = '/queryview/api/read'
        data = self.get_json_resp(url)
        admin = security_manager.find_user('admin')
        self.assertEquals(2, len(data['result']))
        all_admin_user_queries = all([
            result.get('username') == admin.username for result in data['result']
        ])
        assert all_admin_user_queries is True

        # Remove can_only_access_owned_queries from Admin
        owned_queries_view = security_manager.find_permission_view_menu(
            'can_only_access_owned_queries',
            'can_only_access_owned_queries',
        )
        security_manager.del_permission_role(
            security_manager.find_role('Admin'),
            owned_queries_view,
        )

        session.commit()
Example #16
def create_access_request(session, ds_type, ds_name, role_name, user_name):
    ds_class = ConnectorRegistry.sources[ds_type]
    # TODO: generalize datasource names
    if ds_type == 'table':
        ds = session.query(ds_class).filter(
            ds_class.table_name == ds_name).first()
    else:
        ds = session.query(ds_class).filter(
            ds_class.datasource_name == ds_name).first()
    ds_perm_view = security_manager.find_permission_view_menu(
        'datasource_access', ds.perm)
    security_manager.add_permission_role(
        security_manager.find_role(role_name), ds_perm_view)
    access_request = models.DatasourceAccessRequest(
        datasource_id=ds.id,
        datasource_type=ds_type,
        created_by_fk=security_manager.find_user(username=user_name).id,
    )
    session.add(access_request)
    session.commit()
    return access_request
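
For orientation, a minimal sketch of how this helper combines with get_access_requests (Example #12) in the access tests above; TEST_ROLE_1 and the actual grant step come from the surrounding suite:

# Sketch only, mirroring the test cases above (inside a test method of the access-request suite).
request = create_access_request(db.session, 'table', 'birth_names', TEST_ROLE_1, 'gamma')
ds_id = request.datasource_id
# ... gamma is then granted access another way (role extension, database or schema perm) ...
# once the grant is processed, the now-redundant request should have been cleaned up
assert not self.get_access_requests('gamma', 'table', ds_id)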
Example #17
    def test_export_chart_with_query_context(self, mock_g):
        """Test that charts that have a query_context are exported correctly"""

        mock_g.user = security_manager.find_user("alpha")
        example_chart = db.session.query(Slice).filter_by(slice_name="Heatmap").one()
        command = ExportChartsCommand([example_chart.id])

        contents = dict(command.run())

        expected = [
            "metadata.yaml",
            f"charts/Heatmap_{example_chart.id}.yaml",
            "datasets/examples/energy_usage.yaml",
            "databases/examples.yaml",
        ]
        assert expected == list(contents.keys())

        metadata = yaml.safe_load(contents[f"charts/Heatmap_{example_chart.id}.yaml"])

        assert metadata == {
            "slice_name": "Heatmap",
            "viz_type": "heatmap",
            "params": {
                "all_columns_x": "source",
                "all_columns_y": "target",
                "canvas_image_rendering": "pixelated",
                "collapsed_fieldsets": "",
                "linear_color_scheme": "blue_white_yellow",
                "metric": "sum__value",
                "normalize_across": "heatmap",
                "slice_name": "Heatmap",
                "viz_type": "heatmap",
                "xscale_interval": "1",
                "yscale_interval": "1",
            },
            "cache_timeout": None,
            "dataset_uuid": str(example_chart.table.uuid),
            "uuid": str(example_chart.uuid),
            "version": "1.0.0",
        }
Example #18
    def test_sqllab_backend_persistence_payload(self):
        username = "******"
        self.login(username)
        user_id = security_manager.find_user(username).id

        # create a tab
        data = {
            "queryEditor":
            json.dumps({
                "title": "Untitled Query 1",
                "dbId": 1,
                "schema": None,
                "autorun": False,
                "sql": "SELECT ...",
                "queryLimit": 1000,
            })
        }
        resp = self.get_json_resp("/tabstateview/", data=data)
        tab_state_id = resp["id"]

        # run a query in the created tab
        self.run_sql(
            "SELECT name FROM birth_names",
            "client_id_1",
            user_name=username,
            raise_on_error=True,
            sql_editor_id=tab_state_id,
        )
        # run an orphan query (no tab)
        self.run_sql(
            "SELECT name FROM birth_names",
            "client_id_2",
            user_name=username,
            raise_on_error=True,
        )

        # we should have only 1 query returned, since the second one is not
        # associated with any tabs
        payload = views.Superset._get_sqllab_tabs(user_id=user_id)
        self.assertEqual(len(payload["queries"]), 1)
Example #19
    def import_dashboards(path: str, username: Optional[str]) -> None:
        """Import dashboards from ZIP file"""
        # pylint: disable=import-outside-toplevel
        from superset.commands.importers.v1.utils import get_contents_from_bundle
        from superset.dashboards.commands.importers.dispatcher import (
            ImportDashboardsCommand, )

        if username is not None:
            g.user = security_manager.find_user(username=username)
        if is_zipfile(path):
            with ZipFile(path) as bundle:
                contents = get_contents_from_bundle(bundle)
        else:
            with open(path) as file:
                contents = {path: file.read()}
        try:
            ImportDashboardsCommand(contents, overwrite=True).run()
        except Exception:  # pylint: disable=broad-except
            logger.exception(
                "There was an error when importing the dashboards(s), please check "
                "the exception traceback in the log")
            sys.exit(1)
Example #20
    def test_search_query_with_owner_only_perms(self) -> None:
        """
        Test a search query with can_only_access_owned_queries perm added to
        Admin and make sure only Admin queries show up.
        """
        session = db.session

        # Add can_only_access_owned_queries perm to Admin user
        owned_queries_view = security_manager.find_permission_view_menu(
            'can_only_access_owned_queries',
            'can_only_access_owned_queries',
        )
        security_manager.add_permission_role(
            security_manager.find_role('Admin'),
            owned_queries_view,
        )
        session.commit()

        # Test search_queries for Admin user
        self.run_some_queries()
        self.login('admin')

        user_id = security_manager.find_user('admin').id
        data = self.get_json_resp('/superset/search_queries')
        self.assertEquals(2, len(data))
        user_ids = {k['userId'] for k in data}
        self.assertEquals(set([user_id]), user_ids)

        # Remove can_only_access_owned_queries from Admin
        owned_queries_view = security_manager.find_permission_view_menu(
            'can_only_access_owned_queries',
            'can_only_access_owned_queries',
        )
        security_manager.del_permission_role(
            security_manager.find_role('Admin'),
            owned_queries_view,
        )

        session.commit()
Example #21
    def export_datasources(datasource_file: Optional[str]) -> None:
        """Export datasources to ZIP file"""
        from superset.connectors.sqla.models import SqlaTable
        from superset.datasets.commands.export import ExportDatasetsCommand

        g.user = security_manager.find_user(username="******")

        dataset_ids = [id_ for (id_, ) in db.session.query(SqlaTable.id).all()]
        timestamp = datetime.now().strftime("%Y%m%dT%H%M%S")
        root = f"dataset_export_{timestamp}"
        datasource_file = datasource_file or f"{root}.zip"

        try:
            with ZipFile(datasource_file, "w") as bundle:
                for file_name, file_content in ExportDatasetsCommand(
                        dataset_ids).run():
                    with bundle.open(f"{root}/{file_name}", "w") as fp:
                        fp.write(file_content.encode())
        except Exception:  # pylint: disable=broad-except
            logger.exception(
                "There was an error when exporting the datasets, please check "
                "the exception traceback in the log")
Example #22
    def test_export_database_command_key_order(self, mock_g):
        """Test that they keys in the YAML have the same order as export_fields"""
        mock_g.user = security_manager.find_user("admin")

        example_db = get_example_database()
        command = ExportDatabasesCommand([example_db.id])
        contents = dict(command.run())

        metadata = yaml.safe_load(contents["databases/examples.yaml"])
        assert list(metadata.keys()) == [
            "database_name",
            "sqlalchemy_uri",
            "cache_timeout",
            "expose_in_sqllab",
            "allow_run_async",
            "allow_ctas",
            "allow_cvas",
            "allow_csv_upload",
            "extra",
            "uuid",
            "version",
        ]
Example #23
    def test_get_user_datasources_admin(self, mock_get_session,
                                        mock_can_access_database, mock_g):
        Datasource = namedtuple("Datasource", ["database", "schema", "name"])
        mock_g.user = security_manager.find_user("admin")
        mock_can_access_database.return_value = True
        mock_get_session.query.return_value.filter.return_value.all.return_value = []

        with mock.patch.object(
                SqlaTable, "get_all_datasources") as mock_get_all_datasources:
            mock_get_all_datasources.return_value = [
                Datasource("database1", "schema1", "table1"),
                Datasource("database1", "schema1", "table2"),
                Datasource("database2", None, "table1"),
            ]

            datasources = security_manager.get_user_datasources()

        assert sorted(datasources) == [
            Datasource("database1", "schema1", "table1"),
            Datasource("database1", "schema1", "table2"),
            Datasource("database2", None, "table1"),
        ]
Example #24
    def test_export_chart_command_key_order(self, mock_g):
        """Test that they keys in the YAML have the same order as export_fields"""
        mock_g.user = security_manager.find_user("admin")

        example_chart = (
            db.session.query(Slice).filter_by(slice_name="Energy Sankey").one()
        )
        command = ExportChartsCommand([example_chart.id])
        contents = dict(command.run())

        metadata = yaml.safe_load(
            contents[f"charts/Energy_Sankey_{example_chart.id}.yaml"]
        )
        assert list(metadata.keys()) == [
            "slice_name",
            "viz_type",
            "params",
            "cache_timeout",
            "uuid",
            "version",
            "dataset_uuid",
        ]
Example #25
    def test_create_form_data_command_type_as_string(self, mock_g):
        mock_g.user = security_manager.find_user("admin")
        app.config["EXPLORE_FORM_DATA_CACHE_CONFIG"] = {
            "REFRESH_TIMEOUT_ON_RETRIEVAL": True
        }

        dataset = (db.session.query(SqlaTable).filter_by(
            table_name="dummy_sql_table").first())
        slice = db.session.query(Slice).filter_by(
            slice_name="slice_name").first()

        datasource = f"{dataset.id}__{DatasourceType.TABLE}"
        create_args = CommandParameters(
            datasource_id=dataset.id,
            datasource_type="table",
            chart_id=slice.id,
            tab_id=1,
            form_data=json.dumps({"datasource": datasource}),
        )
        command = CreateFormDataCommand(create_args)

        assert isinstance(command.run(), str)
Example #26
    def import_dashboards(path: str, recursive: bool, username: str) -> None:
        """Import dashboards from JSON file"""
        from superset.dashboards.commands.importers.v0 import ImportDashboardsCommand

        path_object = Path(path)
        files: List[Path] = []
        if path_object.is_file():
            files.append(path_object)
        elif path_object.exists() and not recursive:
            files.extend(path_object.glob("*.json"))
        elif path_object.exists() and recursive:
            files.extend(path_object.rglob("*.json"))
        if username is not None:
            g.user = security_manager.find_user(username=username)
        contents = {}
        for path_ in files:
            with open(path_) as file:
                contents[path_.name] = file.read()
        try:
            ImportDashboardsCommand(contents).run()
        except Exception:  # pylint: disable=broad-except
            logger.exception("Error when importing dashboard")
Example #28
def cache_chart_thumbnail(
    url: str,
    digest: str,
    force: bool = False,
    window_size: Optional[WindowSize] = None,
    thumb_size: Optional[WindowSize] = None,
) -> None:
    with app.app_context():  # type: ignore
        if not thumbnail_cache:
            logger.warning("No cache set, refusing to compute")
            return None
        logger.info("Caching chart: %s", url)
        screenshot = ChartScreenshot(url, digest)
        user = security_manager.find_user(current_app.config["THUMBNAIL_SELENIUM_USER"])
        screenshot.compute_and_cache(
            user=user,
            cache=thumbnail_cache,
            force=force,
            window_size=window_size,
            thumb_size=thumb_size,
        )
        return None
Example #29
def import_dashboards(path: str, recursive: bool, username: str) -> None:
    """Import dashboards from JSON"""
    from superset.utils import dashboard_import_export

    path_object = Path(path)
    files = []
    if path_object.is_file():
        files.append(path_object)
    elif path_object.exists() and not recursive:
        files.extend(path_object.glob("*.json"))
    elif path_object.exists() and recursive:
        files.extend(path_object.rglob("*.json"))
    if username is not None:
        g.user = security_manager.find_user(username=username)
    for file_ in files:
        logger.info("Importing dashboard from file %s", file_)
        try:
            with file_.open() as data_stream:
                dashboard_import_export.import_dashboards(data_stream)
        except Exception as ex:  # pylint: disable=broad-except
            logger.error("Error when importing dashboard from file %s", file_)
            logger.error(ex)
Example #30
    def test_queryview_filter_owner_only(self) -> None:
        """
        Test queryview api with can_only_access_owned_queries perm added to
        Admin and make sure only Admin queries show up.
        """
        session = db.session

        # Add can_only_access_owned_queries perm to Admin user
        owned_queries_view = security_manager.find_permission_view_menu(
            "can_only_access_owned_queries", "can_only_access_owned_queries"
        )
        security_manager.add_permission_role(
            security_manager.find_role("Admin"), owned_queries_view
        )
        session.commit()

        # Test search_queries for Admin user
        self.run_some_queries()
        self.login("admin")

        url = "/queryview/api/read"
        data = self.get_json_resp(url)
        admin = security_manager.find_user("admin")
        self.assertEqual(2, len(data["result"]))
        all_admin_user_queries = all(
            [result.get("username") == admin.username for result in data["result"]]
        )
        assert all_admin_user_queries is True

        # Remove can_only_access_owned_queries from Admin
        owned_queries_view = security_manager.find_permission_view_menu(
            "can_only_access_owned_queries", "can_only_access_owned_queries"
        )
        security_manager.del_permission_role(
            security_manager.find_role("Admin"), owned_queries_view
        )

        session.commit()
Example #31
    def test_users_can_view_published_dashboard(self):
        # Create a published and hidden dashboard and add them to the database
        published_dash = models.Dashboard()
        published_dash.dashboard_title = 'Published Dashboard'
        published_dash.slug = 'published_dash'
        published_dash.published = True

        hidden_dash = models.Dashboard()
        hidden_dash.dashboard_title = 'Hidden Dashboard'
        hidden_dash.slug = 'hidden_dash'
        hidden_dash.published = False

        db.session.merge(published_dash)
        db.session.merge(hidden_dash)
        db.session.commit()

        user = security_manager.find_user('alpha')
        self.login(user.username)

        # list dashboards and validate we only see the published dashboard
        resp = self.get_resp('/dashboardmodelview/list/')
        self.assertNotIn('/superset/dashboard/hidden_dash/', resp)
        self.assertIn('/superset/dashboard/published_dash/', resp)
Example #32
    def setUpClass(cls):
        try:
            os.remove(app.config.get('SQL_CELERY_DB_FILE_PATH'))
        except OSError as e:
            app.logger.warn(str(e))
        try:
            os.remove(app.config.get('SQL_CELERY_RESULTS_DB_FILE_PATH'))
        except OSError as e:
            app.logger.warn(str(e))

        security_manager.sync_role_definitions()

        worker_command = BASE_DIR + '/bin/superset worker'
        subprocess.Popen(
            worker_command, shell=True, stdout=subprocess.PIPE)

        admin = security_manager.find_user('admin')
        if not admin:
            security_manager.add_user(
                'admin', 'admin', ' user', '*****@*****.**',
                security_manager.find_role('Admin'),
                password='******')
        cli.load_examples(load_test_data=True)
Example #33
    def test_create_form_data_command_invalid_type(self, mock_g):
        mock_g.user = security_manager.find_user("admin")
        app.config["EXPLORE_FORM_DATA_CACHE_CONFIG"] = {
            "REFRESH_TIMEOUT_ON_RETRIEVAL": True
        }

        dataset = (db.session.query(SqlaTable).filter_by(
            table_name="dummy_sql_table").first())
        slice = db.session.query(Slice).filter_by(
            slice_name="slice_name").first()

        datasource = f"{dataset.id}__{DatasourceType.TABLE}"
        create_args = CommandParameters(
            datasource_id=dataset.id,
            datasource_type="InvalidType",
            chart_id=slice.id,
            tab_id=1,
            form_data=json.dumps({"datasource": datasource}),
        )
        with pytest.raises(DatasourceTypeInvalidError) as exc:
            CreateFormDataCommand(create_args).run()

        assert "Datasource type is invalid" in str(exc.value)
Example #34
def import_dashboards(path, recursive, username):
    """Import dashboards from JSON"""
    from superset.utils import dashboard_import_export

    p = Path(path)
    files = []
    if p.is_file():
        files.append(p)
    elif p.exists() and not recursive:
        files.extend(p.glob("*.json"))
    elif p.exists() and recursive:
        files.extend(p.rglob("*.json"))
    if username is not None:
        g.user = security_manager.find_user(username=username)
    for f in files:
        logging.info("Importing dashboard from file %s", f)
        try:
            with f.open() as data_stream:
                dashboard_import_export.import_dashboards(
                    db.session, data_stream)
        except Exception as e:
            logging.error("Error when importing dashboard from file %s", f)
            logging.error(e)
Example #35
def _get_slice_screenshot(slice_id: int) -> ScreenshotData:
    slice_obj = db.session.query(Slice).get(slice_id)

    chart_url = get_url_path("Superset.slice",
                             slice_id=slice_obj.id,
                             standalone="true")
    screenshot = ChartScreenshot(chart_url, slice_obj.digest)
    image_url = _get_url_path(
        "Superset.slice",
        user_friendly=True,
        slice_id=slice_obj.id,
    )

    user = security_manager.find_user(
        current_app.config["THUMBNAIL_SELENIUM_USER"])
    image_data = screenshot.compute_and_cache(
        user=user,
        cache=thumbnail_cache,
        force=True,
    )

    db.session.commit()
    return ScreenshotData(image_url, image_data)
Example #36
    def setUpClass(cls):
        try:
            os.remove(app.config.get('SQL_CELERY_DB_FILE_PATH'))
        except OSError as e:
            app.logger.warn(str(e))
        try:
            os.remove(app.config.get('SQL_CELERY_RESULTS_DB_FILE_PATH'))
        except OSError as e:
            app.logger.warn(str(e))

        security_manager.sync_role_definitions()

        worker_command = BASE_DIR + '/bin/superset worker'
        subprocess.Popen(
            worker_command, shell=True, stdout=subprocess.PIPE)

        admin = security_manager.find_user('admin')
        if not admin:
            security_manager.add_user(
                'admin', 'admin', ' user', '*****@*****.**',
                security_manager.find_role('Admin'),
                password='******')
        cli.load_examples_run(load_test_data=True)
Example #37
    def test_export_query_command(self, mock_g):
        mock_g.user = security_manager.find_user("admin")

        command = ExportSavedQueriesCommand([self.example_query.id])
        contents = dict(command.run())

        expected = [
            "metadata.yaml",
            "queries/examples/schema1/The_answer.yaml",
            "databases/examples.yaml",
        ]
        assert expected == list(contents.keys())

        metadata = yaml.safe_load(contents["queries/examples/schema1/The_answer.yaml"])
        assert metadata == {
            "schema": "schema1",
            "label": "The answer",
            "description": "Answer to the Ultimate Question of Life, the Universe, and Everything",
            "sql": "SELECT 42",
            "uuid": str(self.example_query.uuid),
            "version": "1.0.0",
            "database_uuid": str(self.example_database.uuid),
        }
Example #38
    def test_export_chart_command(self, mock_g):
        mock_g.user = security_manager.find_user("admin")

        example_chart = (
            db.session.query(Slice).filter_by(slice_name="Energy Sankey").one()
        )
        command = ExportChartsCommand([example_chart.id])
        contents = dict(command.run())

        expected = [
            "metadata.yaml",
            f"charts/Energy_Sankey_{example_chart.id}.yaml",
            "datasets/examples/energy_usage.yaml",
            "databases/examples.yaml",
        ]
        assert expected == list(contents.keys())

        metadata = yaml.safe_load(
            contents[f"charts/Energy_Sankey_{example_chart.id}.yaml"]
        )
        assert metadata == {
            "slice_name": "Energy Sankey",
            "viz_type": "sankey",
            "params": {
                "collapsed_fieldsets": "",
                "groupby": ["source", "target"],
                "metric": "sum__value",
                "row_limit": "5000",
                "slice_name": "Energy Sankey",
                "viz_type": "sankey",
            },
            "query_context": None,
            "cache_timeout": None,
            "dataset_uuid": str(example_chart.table.uuid),
            "uuid": str(example_chart.uuid),
            "version": "1.0.0",
        }
Example #39
    def test_create_v1_response(self, mock_sm_g, mock_c_g, mock_u_g):
        """Test that the create chart command creates a chart"""
        user = security_manager.find_user(username="******")
        mock_u_g.user = mock_c_g.user = mock_sm_g.user = user
        chart_data = {
            "slice_name": "new chart",
            "description": "new description",
            "owners": [user.id],
            "viz_type": "new_viz_type",
            "params": json.dumps({"viz_type": "new_viz_type"}),
            "cache_timeout": 1000,
            "datasource_id": 1,
            "datasource_type": "table",
        }
        command = CreateChartCommand(chart_data)
        chart = command.run()
        chart = db.session.query(Slice).get(chart.id)
        assert chart.viz_type == "new_viz_type"
        json_params = json.loads(chart.params)
        assert json_params == {"viz_type": "new_viz_type"}
        assert chart.slice_name == "new chart"
        assert chart.owners == [user]
        db.session.delete(chart)
        db.session.commit()
Example #40
    def _execute_query(self) -> pd.DataFrame:
        """
        Executes the actual alert SQL query template

        :return: A pandas dataframe
        :raises AlertQueryError: SQL query is not valid
        :raises AlertQueryTimeout: The SQL query received a celery soft timeout
        """
        sql_template = jinja_context.get_template_processor(
            database=self._report_schedule.database
        )
        rendered_sql = sql_template.process_template(self._report_schedule.sql)
        try:
            limited_rendered_sql = self._report_schedule.database.apply_limit_to_sql(
                rendered_sql, ALERT_SQL_LIMIT
            )

            with override_user(
                security_manager.find_user(
                    username=app.config["THUMBNAIL_SELENIUM_USER"]
                )
            ):
                start = default_timer()
                df = self._report_schedule.database.get_df(sql=limited_rendered_sql)
                stop = default_timer()
                logger.info(
                    "Query for %s took %.2f ms",
                    self._report_schedule.name,
                    (stop - start) * 1000.0,
                )
                return df
        except SoftTimeLimitExceeded as ex:
            logger.warning("A timeout occurred while executing the alert query: %s", ex)
            raise AlertQueryTimeout() from ex
        except Exception as ex:
            raise AlertQueryError(message=str(ex)) from ex
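
The docstring above names two failure modes, AlertQueryTimeout and AlertQueryError. A caller-side sketch of handling them; the import location is an assumption and the surrounding command object is not shown in the source:

# Hypothetical caller sketch -- the import path below is an assumption, adjust to your tree.
# from superset.reports.commands.exceptions import AlertQueryError, AlertQueryTimeout
try:
    df = self._execute_query()
except AlertQueryTimeout:
    # the query hit the celery soft time limit; report the alert as timed out
    ...
except AlertQueryError as ex:
    # the query itself failed; log the underlying error message
    logger.warning("Alert query failed: %s", ex)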
Example #41
    def test_context_manager_log(self, mock_g):
        class DummyEventLogger(AbstractEventLogger):
            def __init__(self):
                self.records = []

            def log(
                self,
                user_id: Optional[int],
                action: str,
                dashboard_id: Optional[int],
                duration_ms: Optional[int],
                slice_id: Optional[int],
                referrer: Optional[str],
                *args: Any,
                **kwargs: Any,
            ):
                self.records.append({
                    **kwargs, "user_id": user_id,
                    "duration": duration_ms
                })

        logger = DummyEventLogger()

        with app.test_request_context():
            mock_g.user = security_manager.find_user("gamma")
            with logger(action="foo", engine="bar"):
                pass

        assert logger.records == [{
            "records": [{
                "path": "/",
                "engine": "bar"
            }],
            "user_id": 2,
            "duration": 15000.0,
        }]
Example #42
    def test_get_permalink_command(self, mock_g):
        mock_g.user = security_manager.find_user("admin")
        app.config["EXPLORE_FORM_DATA_CACHE_CONFIG"] = {
            "REFRESH_TIMEOUT_ON_RETRIEVAL": True
        }

        dataset = (db.session.query(SqlaTable).filter_by(
            table_name="dummy_sql_table").first())
        slice = db.session.query(Slice).filter_by(
            slice_name="slice_name").first()

        datasource = f"{dataset.id}__{DatasourceType.TABLE}"

        key = CreateExplorePermalinkCommand({
            "formData": {
                "datasource": datasource,
                "slice_id": slice.id
            }
        }).run()

        get_command = GetExplorePermalinkCommand(key)
        cache_data = get_command.run()

        assert cache_data.get("datasource") == datasource
Example #43
    def test_export_dataset_command(self, mock_g):
        mock_g.user = security_manager.find_user("admin")

        example_db = get_example_database()
        example_dataset = _get_table_from_list_by_name(
            "energy_usage", example_db.tables
        )
        command = ExportDatasetsCommand([example_dataset.id])
        contents = dict(command.run())

        assert list(contents.keys()) == [
            "metadata.yaml",
            "datasets/examples/energy_usage.yaml",
            "databases/examples.yaml",
        ]

        metadata = yaml.safe_load(contents["datasets/examples/energy_usage.yaml"])

        # sort columns for deterministic comparison
        metadata["columns"] = sorted(metadata["columns"], key=itemgetter("column_name"))
        metadata["metrics"] = sorted(metadata["metrics"], key=itemgetter("metric_name"))

        # types are different depending on the backend
        type_map = {
            column.column_name: str(column.type) for column in example_dataset.columns
        }

        assert metadata == {
            "cache_timeout": None,
            "columns": [
                {
                    "column_name": "source",
                    "description": None,
                    "expression": "",
                    "filterable": True,
                    "groupby": True,
                    "is_active": True,
                    "is_dttm": False,
                    "python_date_format": None,
                    "type": type_map["source"],
                    "verbose_name": None,
                },
                {
                    "column_name": "target",
                    "description": None,
                    "expression": "",
                    "filterable": True,
                    "groupby": True,
                    "is_active": True,
                    "is_dttm": False,
                    "python_date_format": None,
                    "type": type_map["target"],
                    "verbose_name": None,
                },
                {
                    "column_name": "value",
                    "description": None,
                    "expression": "",
                    "filterable": True,
                    "groupby": True,
                    "is_active": True,
                    "is_dttm": False,
                    "python_date_format": None,
                    "type": type_map["value"],
                    "verbose_name": None,
                },
            ],
            "database_uuid": str(example_db.uuid),
            "default_endpoint": None,
            "description": "Energy consumption",
            "extra": None,
            "fetch_values_predicate": None,
            "filter_select_enabled": False,
            "main_dttm_col": None,
            "metrics": [
                {
                    "d3format": None,
                    "description": None,
                    "expression": "COUNT(*)",
                    "extra": None,
                    "metric_name": "count",
                    "metric_type": "count",
                    "verbose_name": "COUNT(*)",
                    "warning_text": None,
                },
                {
                    "d3format": None,
                    "description": None,
                    "expression": "SUM(value)",
                    "extra": None,
                    "metric_name": "sum__value",
                    "metric_type": None,
                    "verbose_name": None,
                    "warning_text": None,
                },
            ],
            "offset": 0,
            "params": None,
            "schema": None,
            "sql": None,
            "table_name": "energy_usage",
            "template_params": None,
            "uuid": str(example_dataset.uuid),
            "version": "1.0.0",
        }
Example #44
def load_birth_names(only_metadata: bool = False,
                     force: bool = False,
                     sample: bool = False) -> None:
    """Loading birth name dataset from a zip file in the repo"""
    # pylint: disable=too-many-locals
    tbl_name = "birth_names"
    database = get_example_database()
    table_exists = database.has_table_by_name(tbl_name)

    if not only_metadata and (not table_exists or force):
        load_data(tbl_name, database, sample=sample)

    obj = db.session.query(TBL).filter_by(table_name=tbl_name).first()
    if not obj:
        print(f"Creating table [{tbl_name}] reference")
        obj = TBL(table_name=tbl_name)
        db.session.add(obj)
    obj.main_dttm_col = "ds"
    obj.database = database
    obj.filter_select_enabled = True

    if not any(col.column_name == "num_california" for col in obj.columns):
        col_state = str(column("state").compile(db.engine))
        col_num = str(column("num").compile(db.engine))
        obj.columns.append(
            TableColumn(
                column_name="num_california",
                expression=
                f"CASE WHEN {col_state} = 'CA' THEN {col_num} ELSE 0 END",
            ))

    if not any(col.metric_name == "sum__num" for col in obj.metrics):
        col = str(column("num").compile(db.engine))
        obj.metrics.append(
            SqlMetric(metric_name="sum__num", expression=f"SUM({col})"))

    db.session.commit()
    obj.fetch_metadata()
    tbl = obj

    metrics = [{
        "expressionType": "SIMPLE",
        "column": {
            "column_name": "num",
            "type": "BIGINT"
        },
        "aggregate": "SUM",
        "label": "Births",
        "optionName": "metric_11",
    }]
    metric = "sum__num"

    defaults = {
        "compare_lag": "10",
        "compare_suffix": "o10Y",
        "limit": "25",
        "granularity_sqla": "ds",
        "groupby": [],
        "row_limit": config["ROW_LIMIT"],
        "since": "100 years ago",
        "until": "now",
        "viz_type": "table",
        "markup_type": "markdown",
    }

    admin = security_manager.find_user("admin")

    print("Creating some slices")
    slices = [
        Slice(
            slice_name="Participants",
            viz_type="big_number",
            datasource_type="table",
            datasource_id=tbl.id,
            params=get_slice_json(
                defaults,
                viz_type="big_number",
                granularity_sqla="ds",
                compare_lag="5",
                compare_suffix="over 5Y",
                metric=metric,
            ),
        ),
        Slice(
            slice_name="Genders",
            viz_type="pie",
            datasource_type="table",
            datasource_id=tbl.id,
            params=get_slice_json(defaults,
                                  viz_type="pie",
                                  groupby=["gender"],
                                  metric=metric),
        ),
        Slice(
            slice_name="Trends",
            viz_type="line",
            datasource_type="table",
            datasource_id=tbl.id,
            params=get_slice_json(
                defaults,
                viz_type="line",
                groupby=["name"],
                granularity_sqla="ds",
                rich_tooltip=True,
                show_legend=True,
                metrics=metrics,
            ),
        ),
        Slice(
            slice_name="Genders by State",
            viz_type="dist_bar",
            datasource_type="table",
            datasource_id=tbl.id,
            params=get_slice_json(
                defaults,
                adhoc_filters=[{
                    "clause": "WHERE",
                    "expressionType": "SIMPLE",
                    "filterOptionName": "2745eae5",
                    "comparator": ["other"],
                    "operator": "NOT IN",
                    "subject": "state",
                }],
                viz_type="dist_bar",
                metrics=[
                    {
                        "expressionType": "SIMPLE",
                        "column": {
                            "column_name": "sum_boys",
                            "type": "BIGINT(20)"
                        },
                        "aggregate": "SUM",
                        "label": "Boys",
                        "optionName": "metric_11",
                    },
                    {
                        "expressionType": "SIMPLE",
                        "column": {
                            "column_name": "sum_girls",
                            "type": "BIGINT(20)"
                        },
                        "aggregate": "SUM",
                        "label": "Girls",
                        "optionName": "metric_12",
                    },
                ],
                groupby=["state"],
            ),
        ),
        Slice(
            slice_name="Girls",
            viz_type="table",
            datasource_type="table",
            datasource_id=tbl.id,
            params=get_slice_json(
                defaults,
                groupby=["name"],
                adhoc_filters=[gen_filter("gender", "girl")],
                row_limit=50,
                timeseries_limit_metric="sum__num",
                metrics=metrics,
            ),
        ),
        Slice(
            slice_name="Girl Name Cloud",
            viz_type="word_cloud",
            datasource_type="table",
            datasource_id=tbl.id,
            params=get_slice_json(
                defaults,
                viz_type="word_cloud",
                size_from="10",
                series="name",
                size_to="70",
                rotation="square",
                limit="100",
                adhoc_filters=[gen_filter("gender", "girl")],
                metric=metric,
            ),
        ),
        Slice(
            slice_name="Boys",
            viz_type="table",
            datasource_type="table",
            datasource_id=tbl.id,
            params=get_slice_json(
                defaults,
                groupby=["name"],
                adhoc_filters=[gen_filter("gender", "boy")],
                row_limit=50,
                metrics=metrics,
            ),
        ),
        Slice(
            slice_name="Boy Name Cloud",
            viz_type="word_cloud",
            datasource_type="table",
            datasource_id=tbl.id,
            params=get_slice_json(
                defaults,
                viz_type="word_cloud",
                size_from="10",
                series="name",
                size_to="70",
                rotation="square",
                limit="100",
                adhoc_filters=[gen_filter("gender", "boy")],
                metric=metric,
            ),
        ),
        Slice(
            slice_name="Top 10 Girl Name Share",
            viz_type="area",
            datasource_type="table",
            datasource_id=tbl.id,
            params=get_slice_json(
                defaults,
                adhoc_filters=[gen_filter("gender", "girl")],
                comparison_type="values",
                groupby=["name"],
                limit=10,
                stacked_style="expand",
                time_grain_sqla="P1D",
                viz_type="area",
                x_axis_format="smart_date",
                metrics=metrics,
            ),
        ),
        Slice(
            slice_name="Top 10 Boy Name Share",
            viz_type="area",
            datasource_type="table",
            datasource_id=tbl.id,
            params=get_slice_json(
                defaults,
                adhoc_filters=[gen_filter("gender", "boy")],
                comparison_type="values",
                groupby=["name"],
                limit=10,
                stacked_style="expand",
                time_grain_sqla="P1D",
                viz_type="area",
                x_axis_format="smart_date",
                metrics=metrics,
            ),
        ),
    ]
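    # The slices below are also registered in misc_dash_slices so they can be
    # picked up by the "Misc Charts" example dashboard.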
    misc_slices = [
        Slice(
            slice_name="Average and Sum Trends",
            viz_type="dual_line",
            datasource_type="table",
            datasource_id=tbl.id,
            params=get_slice_json(
                defaults,
                viz_type="dual_line",
                metric={
                    "expressionType": "SIMPLE",
                    "column": {
                        "column_name": "num",
                        "type": "BIGINT(20)"
                    },
                    "aggregate": "AVG",
                    "label": "AVG(num)",
                    "optionName": "metric_vgops097wej_g8uff99zhk7",
                },
                metric_2="sum__num",
                granularity_sqla="ds",
                metrics=metrics,
            ),
        ),
        Slice(
            slice_name="Num Births Trend",
            viz_type="line",
            datasource_type="table",
            datasource_id=tbl.id,
            params=get_slice_json(defaults, viz_type="line", metrics=metrics),
        ),
        Slice(
            slice_name="Daily Totals",
            viz_type="table",
            datasource_type="table",
            datasource_id=tbl.id,
            created_by=admin,
            params=get_slice_json(
                defaults,
                groupby=["ds"],
                since="40 years ago",
                until="now",
                viz_type="table",
                metrics=metrics,
            ),
        ),
        Slice(
            slice_name="Number of California Births",
            viz_type="big_number_total",
            datasource_type="table",
            datasource_id=tbl.id,
            params=get_slice_json(
                defaults,
                metric={
                    "expressionType": "SIMPLE",
                    "column": {
                        "column_name": "num_california",
                        "expression":
                        "CASE WHEN state = 'CA' THEN num ELSE 0 END",
                    },
                    "aggregate": "SUM",
                    "label": "SUM(num_california)",
                },
                viz_type="big_number_total",
                granularity_sqla="ds",
            ),
        ),
        Slice(
            slice_name="Top 10 California Names Timeseries",
            viz_type="line",
            datasource_type="table",
            datasource_id=tbl.id,
            params=get_slice_json(
                defaults,
                metrics=[{
                    "expressionType": "SIMPLE",
                    "column": {
                        "column_name": "num_california",
                        "expression":
                        "CASE WHEN state = 'CA' THEN num ELSE 0 END",
                    },
                    "aggregate": "SUM",
                    "label": "SUM(num_california)",
                }],
                viz_type="line",
                granularity_sqla="ds",
                groupby=["name"],
                timeseries_limit_metric={
                    "expressionType": "SIMPLE",
                    "column": {
                        "column_name": "num_california",
                        "expression":
                        "CASE WHEN state = 'CA' THEN num ELSE 0 END",
                    },
                    "aggregate": "SUM",
                    "label": "SUM(num_california)",
                },
                limit="10",
            ),
        ),
        Slice(
            slice_name="Names Sorted by Num in California",
            viz_type="table",
            datasource_type="table",
            datasource_id=tbl.id,
            params=get_slice_json(
                defaults,
                metrics=metrics,
                groupby=["name"],
                row_limit=50,
                timeseries_limit_metric={
                    "expressionType": "SIMPLE",
                    "column": {
                        "column_name": "num_california",
                        "expression":
                        "CASE WHEN state = 'CA' THEN num ELSE 0 END",
                    },
                    "aggregate": "SUM",
                    "label": "SUM(num_california)",
                },
            ),
        ),
        Slice(
            slice_name="Number of Girls",
            viz_type="big_number_total",
            datasource_type="table",
            datasource_id=tbl.id,
            params=get_slice_json(
                defaults,
                metric=metric,
                viz_type="big_number_total",
                granularity_sqla="ds",
                adhoc_filters=[gen_filter("gender", "girl")],
                subheader="total female participants",
            ),
        ),
        Slice(
            slice_name="Pivot Table",
            viz_type="pivot_table",
            datasource_type="table",
            datasource_id=tbl.id,
            params=get_slice_json(
                defaults,
                viz_type="pivot_table",
                groupby=["name"],
                columns=["state"],
                metrics=metrics,
            ),
        ),
    ]
    for slc in slices:
        merge_slice(slc)

    for slc in misc_slices:
        merge_slice(slc)
        misc_dash_slices.add(slc.slice_name)

    print("Creating a dashboard")
    dash = db.session.query(Dashboard).filter_by(slug="births").first()

    if not dash:
        dash = Dashboard()
        db.session.add(dash)
    dash.published = True
    dash.json_metadata = textwrap.dedent("""\
    {
        "label_colors": {
            "Girls": "#FF69B4",
            "Boys": "#ADD8E6",
            "girl": "#FF69B4",
            "boy": "#ADD8E6"
        }
    }""")
    js = textwrap.dedent(
        # pylint: disable=line-too-long
        """\
        {
          "CHART-6GdlekVise": {
            "children": [],
            "id": "CHART-6GdlekVise",
            "meta": {
              "chartId": 5547,
              "height": 50,
              "sliceName": "Top 10 Girl Name Share",
              "width": 5
            },
            "parents": [
              "ROOT_ID",
              "GRID_ID",
              "ROW-eh0w37bWbR"
            ],
            "type": "CHART"
          },
          "CHART-6n9jxb30JG": {
            "children": [],
            "id": "CHART-6n9jxb30JG",
            "meta": {
              "chartId": 5540,
              "height": 36,
              "sliceName": "Genders by State",
              "width": 5
            },
            "parents": [
              "ROOT_ID",
              "GRID_ID",
              "ROW--EyBZQlDi"
            ],
            "type": "CHART"
          },
          "CHART-Jj9qh1ol-N": {
            "children": [],
            "id": "CHART-Jj9qh1ol-N",
            "meta": {
              "chartId": 5545,
              "height": 50,
              "sliceName": "Boy Name Cloud",
              "width": 4
            },
            "parents": [
              "ROOT_ID",
              "GRID_ID",
              "ROW-kzWtcvo8R1"
            ],
            "type": "CHART"
          },
          "CHART-ODvantb_bF": {
            "children": [],
            "id": "CHART-ODvantb_bF",
            "meta": {
              "chartId": 5548,
              "height": 50,
              "sliceName": "Top 10 Boy Name Share",
              "width": 5
            },
            "parents": [
              "ROOT_ID",
              "GRID_ID",
              "ROW-kzWtcvo8R1"
            ],
            "type": "CHART"
          },
          "CHART-PAXUUqwmX9": {
            "children": [],
            "id": "CHART-PAXUUqwmX9",
            "meta": {
              "chartId": 5538,
              "height": 34,
              "sliceName": "Genders",
              "width": 3
            },
            "parents": [
              "ROOT_ID",
              "GRID_ID",
              "ROW-2n0XgiHDgs"
            ],
            "type": "CHART"
          },
          "CHART-_T6n_K9iQN": {
            "children": [],
            "id": "CHART-_T6n_K9iQN",
            "meta": {
              "chartId": 5539,
              "height": 36,
              "sliceName": "Trends",
              "width": 7
            },
            "parents": [
              "ROOT_ID",
              "GRID_ID",
              "ROW--EyBZQlDi"
            ],
            "type": "CHART"
          },
          "CHART-eNY0tcE_ic": {
            "children": [],
            "id": "CHART-eNY0tcE_ic",
            "meta": {
              "chartId": 5537,
              "height": 34,
              "sliceName": "Participants",
              "width": 3
            },
            "parents": [
              "ROOT_ID",
              "GRID_ID",
              "ROW-2n0XgiHDgs"
            ],
            "type": "CHART"
          },
          "CHART-g075mMgyYb": {
            "children": [],
            "id": "CHART-g075mMgyYb",
            "meta": {
              "chartId": 5541,
              "height": 50,
              "sliceName": "Girls",
              "width": 3
            },
            "parents": [
              "ROOT_ID",
              "GRID_ID",
              "ROW-eh0w37bWbR"
            ],
            "type": "CHART"
          },
          "CHART-n-zGGE6S1y": {
            "children": [],
            "id": "CHART-n-zGGE6S1y",
            "meta": {
              "chartId": 5542,
              "height": 50,
              "sliceName": "Girl Name Cloud",
              "width": 4
            },
            "parents": [
              "ROOT_ID",
              "GRID_ID",
              "ROW-eh0w37bWbR"
            ],
            "type": "CHART"
          },
          "CHART-vJIPjmcbD3": {
            "children": [],
            "id": "CHART-vJIPjmcbD3",
            "meta": {
              "chartId": 5543,
              "height": 50,
              "sliceName": "Boys",
              "width": 3
            },
            "parents": [
              "ROOT_ID",
              "GRID_ID",
              "ROW-kzWtcvo8R1"
            ],
            "type": "CHART"
          },
          "DASHBOARD_VERSION_KEY": "v2",
          "GRID_ID": {
            "children": [
              "ROW-2n0XgiHDgs",
              "ROW--EyBZQlDi",
              "ROW-eh0w37bWbR",
              "ROW-kzWtcvo8R1"
            ],
            "id": "GRID_ID",
            "parents": [
              "ROOT_ID"
            ],
            "type": "GRID"
          },
          "HEADER_ID": {
            "id": "HEADER_ID",
            "meta": {
              "text": "Births"
            },
            "type": "HEADER"
          },
          "MARKDOWN-zaflB60tbC": {
            "children": [],
            "id": "MARKDOWN-zaflB60tbC",
            "meta": {
              "code": "<div style=\\"text-align:center\\">  <h1>Birth Names Dashboard</h1>  <img src=\\"/static/assets/images/babies.png\\" style=\\"width:50%;\\"></div>",
              "height": 34,
              "width": 6
            },
            "parents": [
              "ROOT_ID",
              "GRID_ID",
              "ROW-2n0XgiHDgs"
            ],
            "type": "MARKDOWN"
          },
          "ROOT_ID": {
            "children": [
              "GRID_ID"
            ],
            "id": "ROOT_ID",
            "type": "ROOT"
          },
          "ROW--EyBZQlDi": {
            "children": [
              "CHART-_T6n_K9iQN",
              "CHART-6n9jxb30JG"
            ],
            "id": "ROW--EyBZQlDi",
            "meta": {
              "background": "BACKGROUND_TRANSPARENT"
            },
            "parents": [
              "ROOT_ID",
              "GRID_ID"
            ],
            "type": "ROW"
          },
          "ROW-2n0XgiHDgs": {
            "children": [
              "CHART-eNY0tcE_ic",
              "MARKDOWN-zaflB60tbC",
              "CHART-PAXUUqwmX9"
            ],
            "id": "ROW-2n0XgiHDgs",
            "meta": {
              "background": "BACKGROUND_TRANSPARENT"
            },
            "parents": [
              "ROOT_ID",
              "GRID_ID"
            ],
            "type": "ROW"
          },
          "ROW-eh0w37bWbR": {
            "children": [
              "CHART-g075mMgyYb",
              "CHART-n-zGGE6S1y",
              "CHART-6GdlekVise"
            ],
            "id": "ROW-eh0w37bWbR",
            "meta": {
              "background": "BACKGROUND_TRANSPARENT"
            },
            "parents": [
              "ROOT_ID",
              "GRID_ID"
            ],
            "type": "ROW"
          },
          "ROW-kzWtcvo8R1": {
            "children": [
              "CHART-vJIPjmcbD3",
              "CHART-Jj9qh1ol-N",
              "CHART-ODvantb_bF"
            ],
            "id": "ROW-kzWtcvo8R1",
            "meta": {
              "background": "BACKGROUND_TRANSPARENT"
            },
            "parents": [
              "ROOT_ID",
              "GRID_ID"
            ],
            "type": "ROW"
          }
        }
        """

        # pylint: enable=line-too-long
    )
    pos = json.loads(js)
    # dashboard v2 doesn't allow adding markup slices
    dash.slices = [slc for slc in slices if slc.viz_type != "markup"]
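    # Swap the placeholder chartIds in the layout for the ids of the slices
    # that were just merged above.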
    update_slice_ids(pos, dash.slices)
    dash.dashboard_title = "USA Births Names"
    dash.position_json = json.dumps(pos, indent=4)
    dash.slug = "births"
    db.session.commit()
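For reference, the get_slice_json and gen_filter helpers used throughout the slice definitions above are small utilities from Superset's example loaders; they are not shown on this page. A minimal sketch of their behavior, assuming the standard adhoc-filter shape (the real implementations may differ in detail):

import json


def get_slice_json(defaults, **kwargs):
    # Merge per-slice overrides into the shared defaults and serialize as JSON
    opts = defaults.copy()
    opts.update(kwargs)
    return json.dumps(opts, indent=4, sort_keys=True)


def gen_filter(subject, comparator, operator="=="):
    # Build one adhoc filter dict in the shape the slices above expect
    return {
        "clause": "WHERE",
        "expressionType": "SIMPLE",
        "comparator": comparator,
        "operator": operator,
        "subject": subject,
    }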
    def test_get_user_slices(self):
        self.login(username="******")
        userid = security_manager.find_user("admin").id
        url = f"/sliceasync/api/read?_flt_0_created_by={userid}"
        resp = self.client.get(url)
        self.assertEqual(resp.status_code, 200)
    def test_request_access(self):
        if app.config.get("ENABLE_ACCESS_REQUEST"):
            session = db.session
            self.logout()
            self.login(username="******")
            gamma_user = security_manager.find_user(username="******")
            security_manager.add_role("dummy_role")
            gamma_user.roles.append(security_manager.find_role("dummy_role"))
            session.commit()

            ACCESS_REQUEST = (
                "/superset/request_access?"
                "datasource_type={}&"
                "datasource_id={}&"
                "action={}&"
            )
            ROLE_GRANT_LINK = (
                '<a href="/superset/approve?datasource_type={}&datasource_id={}&'
                'created_by={}&role_to_grant={}">Grant {} Role</a>'
            )

            # Request table access; no existing role has access to this table.

            table1 = (
                session.query(SqlaTable)
                .filter_by(table_name="random_time_series")
                .first()
            )
            table_1_id = table1.id

            # request access to the table
            resp = self.get_resp(ACCESS_REQUEST.format("table", table_1_id, "go"))
            assert "Access was requested" in resp
            access_request1 = self.get_access_requests("gamma", "table", table_1_id)
            assert access_request1 is not None

            # Request access; roles exist that contain the table.
            # add table to the existing roles
            table3 = (
                session.query(SqlaTable).filter_by(table_name="energy_usage").first()
            )
            table_3_id = table3.id
            table3_perm = table3.perm

            security_manager.add_role("energy_usage_role")
            alpha_role = security_manager.find_role("Alpha")
            security_manager.add_permission_role(
                alpha_role,
                security_manager.find_permission_view_menu(
                    "datasource_access", table3_perm
                ),
            )
            security_manager.add_permission_role(
                security_manager.find_role("energy_usage_role"),
                security_manager.find_permission_view_menu(
                    "datasource_access", table3_perm
                ),
            )
            session.commit()

            self.get_resp(ACCESS_REQUEST.format("table", table_3_id, "go"))
            access_request3 = self.get_access_requests("gamma", "table", table_3_id)
            approve_link_3 = ROLE_GRANT_LINK.format(
                "table", table_3_id, "gamma", "energy_usage_role", "energy_usage_role"
            )
            self.assertEqual(
                access_request3.roles_with_datasource,
                "<ul><li>{}</li></ul>".format(approve_link_3),
            )

            # Request druid access; no existing role has access to this datasource.
            druid_ds_4 = (
                session.query(DruidDatasource)
                .filter_by(datasource_name="druid_ds_1")
                .first()
            )
            druid_ds_4_id = druid_ds_4.id

            # request access to the datasource
            self.get_resp(ACCESS_REQUEST.format("druid", druid_ds_4_id, "go"))
            access_request4 = self.get_access_requests("gamma", "druid", druid_ds_4_id)

            self.assertEqual(
                access_request4.roles_with_datasource,
                "<ul></ul>".format(access_request4.id),
            )

            # Case 5. Roles exist that contain the druid datasource.
            # add druid ds to the existing roles
            druid_ds_5 = (
                session.query(DruidDatasource)
                .filter_by(datasource_name="druid_ds_2")
                .first()
            )
            druid_ds_5_id = druid_ds_5.id
            druid_ds_5_perm = druid_ds_5.perm

            druid_ds_2_role = security_manager.add_role("druid_ds_2_role")
            admin_role = security_manager.find_role("Admin")
            security_manager.add_permission_role(
                admin_role,
                security_manager.find_permission_view_menu(
                    "datasource_access", druid_ds_5_perm
                ),
            )
            security_manager.add_permission_role(
                druid_ds_2_role,
                security_manager.find_permission_view_menu(
                    "datasource_access", druid_ds_5_perm
                ),
            )
            session.commit()

            self.get_resp(ACCESS_REQUEST.format("druid", druid_ds_5_id, "go"))
            access_request5 = self.get_access_requests("gamma", "druid", druid_ds_5_id)
            approve_link_5 = ROLE_GRANT_LINK.format(
                "druid", druid_ds_5_id, "gamma", "druid_ds_2_role", "druid_ds_2_role"
            )
            self.assertEqual(
                access_request5.roles_with_datasource,
                "<ul><li>{}</li></ul>".format(approve_link_5),
            )

            # cleanup
            gamma_user = security_manager.find_user(username="******")
            gamma_user.roles.remove(security_manager.find_role("dummy_role"))
            session.commit()
    def test_approve(self, mock_send_mime):
        if app.config.get("ENABLE_ACCESS_REQUEST"):
            session = db.session
            TEST_ROLE_NAME = "table_role"
            security_manager.add_role(TEST_ROLE_NAME)

            # Case 1. Grant new role to the user.

            access_request1 = create_access_request(
                session, "table", "unicode_test", TEST_ROLE_NAME, "gamma"
            )
            ds_1_id = access_request1.datasource_id
            self.get_resp(
                GRANT_ROLE_REQUEST.format("table", ds_1_id, "gamma", TEST_ROLE_NAME)
            )
            # Test email content.
            self.assertTrue(mock_send_mime.called)
            call_args = mock_send_mime.call_args[0]
            self.assertEqual(
                [
                    security_manager.find_user(username="******").email,
                    security_manager.find_user(username="******").email,
                ],
                call_args[1],
            )
            self.assertEqual(
                "[Superset] Access to the datasource {} was granted".format(
                    self.get_table(ds_1_id).full_name
                ),
                call_args[2]["Subject"],
            )
            self.assertIn(TEST_ROLE_NAME, call_args[2].as_string())
            self.assertIn("unicode_test", call_args[2].as_string())

            access_requests = self.get_access_requests("gamma", "table", ds_1_id)
            # request was removed
            self.assertFalse(access_requests)
            # user was granted table_role
            user_roles = [r.name for r in security_manager.find_user("gamma").roles]
            self.assertIn(TEST_ROLE_NAME, user_roles)

            # Case 2. Extend the role to have access to the table

            access_request2 = create_access_request(
                session, "table", "energy_usage", TEST_ROLE_NAME, "gamma"
            )
            ds_2_id = access_request2.datasource_id
            energy_usage_perm = access_request2.datasource.perm

            self.client.get(
                EXTEND_ROLE_REQUEST.format(
                    "table", access_request2.datasource_id, "gamma", TEST_ROLE_NAME
                )
            )
            access_requests = self.get_access_requests("gamma", "table", ds_2_id)

            # Test email content.
            self.assertTrue(mock_send_mime.called)
            call_args = mock_send_mime.call_args[0]
            self.assertEqual(
                [
                    security_manager.find_user(username="******").email,
                    security_manager.find_user(username="******").email,
                ],
                call_args[1],
            )
            self.assertEqual(
                "[Superset] Access to the datasource {} was granted".format(
                    self.get_table(ds_2_id).full_name
                ),
                call_args[2]["Subject"],
            )
            self.assertIn(TEST_ROLE_NAME, call_args[2].as_string())
            self.assertIn("energy_usage", call_args[2].as_string())

            # request was removed
            self.assertFalse(access_requests)
            # table_role was extended to grant access to the energy_usage table
            perm_view = security_manager.find_permission_view_menu(
                "datasource_access", energy_usage_perm
            )
            TEST_ROLE = security_manager.find_role(TEST_ROLE_NAME)
            self.assertIn(perm_view, TEST_ROLE.permissions)

            # Case 3. Grant new role to the user to access the druid datasource.

            security_manager.add_role("druid_role")
            access_request3 = create_access_request(
                session, "druid", "druid_ds_1", "druid_role", "gamma"
            )
            self.get_resp(
                GRANT_ROLE_REQUEST.format(
                    "druid", access_request3.datasource_id, "gamma", "druid_role"
                )
            )

            # user was granted druid_role
            user_roles = [r.name for r in security_manager.find_user("gamma").roles]
            self.assertIn("druid_role", user_roles)

            # Case 4. Extend the role to have access to the druid datasource

            access_request4 = create_access_request(
                session, "druid", "druid_ds_2", "druid_role", "gamma"
            )
            druid_ds_2_perm = access_request4.datasource.perm

            self.client.get(
                EXTEND_ROLE_REQUEST.format(
                    "druid", access_request4.datasource_id, "gamma", "druid_role"
                )
            )
            # druid_role was extended to grant access to druid_ds_2
            druid_role = security_manager.find_role("druid_role")
            perm_view = security_manager.find_permission_view_menu(
                "datasource_access", druid_ds_2_perm
            )
            self.assertIn(perm_view, druid_role.permissions)

            # cleanup
            gamma_user = security_manager.find_user(username="******")
            gamma_user.roles.remove(security_manager.find_role("druid_role"))
            gamma_user.roles.remove(security_manager.find_role(TEST_ROLE_NAME))
            session.delete(security_manager.find_role("druid_role"))
            session.delete(security_manager.find_role(TEST_ROLE_NAME))
            session.commit()
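The GRANT_ROLE_REQUEST and EXTEND_ROLE_REQUEST constants used in the test above are module-level URL templates for the /superset/approve endpoint and are defined elsewhere in the test module. A plausible sketch, assuming role_to_grant/role_to_extend are the query parameters that endpoint reads (the exact strings may differ):

GRANT_ROLE_REQUEST = (
    "/superset/approve?datasource_type={}&datasource_id={}&"
    "created_by={}&role_to_grant={}"
)
EXTEND_ROLE_REQUEST = (
    "/superset/approve?datasource_type={}&datasource_id={}&"
    "created_by={}&role_to_extend={}"
)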
Example #48
0
    def __init__(self, *args, **kwargs):
        if (
            self.requires_examples and
            not os.environ.get('examples_loaded')
        ):
            logging.info('Loading examples')
            cli.load_examples(load_test_data=True)
            logging.info('Done loading examples')
            security_manager.sync_role_definitions()
            os.environ['examples_loaded'] = '1'
        else:
            security_manager.sync_role_definitions()
        super(SupersetTestCase, self).__init__(*args, **kwargs)
        self.client = app.test_client()
        self.maxDiff = None

        gamma_sqllab_role = security_manager.add_role('gamma_sqllab')
        for perm in security_manager.find_role('Gamma').permissions:
            security_manager.add_permission_role(gamma_sqllab_role, perm)
        utils.get_or_create_main_db()
        db_perm = self.get_main_database(security_manager.get_session).perm
        security_manager.merge_perm('database_access', db_perm)
        db_pvm = security_manager.find_permission_view_menu(
            view_menu_name=db_perm, permission_name='database_access')
        gamma_sqllab_role.permissions.append(db_pvm)
        for perm in security_manager.find_role('sql_lab').permissions:
            security_manager.add_permission_role(gamma_sqllab_role, perm)

        admin = security_manager.find_user('admin')
        if not admin:
            security_manager.add_user(
                'admin', 'admin', 'user', '*****@*****.**',
                security_manager.find_role('Admin'),
                password='******')

        gamma = security_manager.find_user('gamma')
        if not gamma:
            security_manager.add_user(
                'gamma', 'gamma', 'user', '*****@*****.**',
                security_manager.find_role('Gamma'),
                password='******')

        gamma2 = security_manager.find_user('gamma2')
        if not gamma2:
            security_manager.add_user(
                'gamma2', 'gamma2', 'user', '*****@*****.**',
                security_manager.find_role('Gamma'),
                password='******')

        gamma_sqllab_user = security_manager.find_user('gamma_sqllab')
        if not gamma_sqllab_user:
            security_manager.add_user(
                'gamma_sqllab', 'gamma_sqllab', 'user', '*****@*****.**',
                gamma_sqllab_role, password='******')

        alpha = security_manager.find_user('alpha')
        if not alpha:
            security_manager.add_user(
                'alpha', 'alpha', 'user', '*****@*****.**',
                security_manager.find_role('Alpha'),
                password='******')
        security_manager.get_session.commit()
        # create druid cluster and druid datasources
        session = db.session
        cluster = (
            session.query(DruidCluster)
            .filter_by(cluster_name='druid_test')
            .first()
        )
        if not cluster:
            cluster = DruidCluster(cluster_name='druid_test')
            session.add(cluster)
            session.commit()

            druid_datasource1 = DruidDatasource(
                datasource_name='druid_ds_1',
                cluster_name='druid_test',
            )
            session.add(druid_datasource1)
            druid_datasource2 = DruidDatasource(
                datasource_name='druid_ds_2',
                cluster_name='druid_test',
            )
            session.add(druid_datasource2)
            session.commit()
    def test_get_user_slices(self):
        self.login(username='******')
        userid = security_manager.find_user('admin').id
        url = '/sliceaddview/api/read?_flt_0_created_by={}'.format(userid)
        resp = self.client.get(url)
        self.assertEqual(resp.status_code, 200)
Example #51
0
    def test_approve(self, mock_send_mime):
        if app.config.get('ENABLE_ACCESS_REQUEST'):
            session = db.session
            TEST_ROLE_NAME = 'table_role'
            security_manager.add_role(TEST_ROLE_NAME)

            # Case 1. Grant new role to the user.

            access_request1 = create_access_request(
                session, 'table', 'unicode_test', TEST_ROLE_NAME, 'gamma')
            ds_1_id = access_request1.datasource_id
            self.get_resp(GRANT_ROLE_REQUEST.format(
                'table', ds_1_id, 'gamma', TEST_ROLE_NAME))
            # Test email content.
            self.assertTrue(mock_send_mime.called)
            call_args = mock_send_mime.call_args[0]
            self.assertEqual([security_manager.find_user(username='******').email,
                              security_manager.find_user(username='******').email],
                             call_args[1])
            self.assertEqual(
                '[Superset] Access to the datasource {} was granted'.format(
                    self.get_table(ds_1_id).full_name), call_args[2]['Subject'])
            self.assertIn(TEST_ROLE_NAME, call_args[2].as_string())
            self.assertIn('unicode_test', call_args[2].as_string())

            access_requests = self.get_access_requests('gamma', 'table', ds_1_id)
            # request was removed
            self.assertFalse(access_requests)
            # user was granted table_role
            user_roles = [r.name for r in security_manager.find_user('gamma').roles]
            self.assertIn(TEST_ROLE_NAME, user_roles)

            # Case 2. Extend the role to have access to the table

            access_request2 = create_access_request(
                session, 'table', 'energy_usage', TEST_ROLE_NAME, 'gamma')
            ds_2_id = access_request2.datasource_id
            energy_usage_perm = access_request2.datasource.perm

            self.client.get(EXTEND_ROLE_REQUEST.format(
                'table', access_request2.datasource_id, 'gamma', TEST_ROLE_NAME))
            access_requests = self.get_access_requests('gamma', 'table', ds_2_id)

            # Test email content.
            self.assertTrue(mock_send_mime.called)
            call_args = mock_send_mime.call_args[0]
            self.assertEqual([security_manager.find_user(username='******').email,
                              security_manager.find_user(username='******').email],
                             call_args[1])
            self.assertEqual(
                '[Superset] Access to the datasource {} was granted'.format(
                    self.get_table(ds_2_id).full_name), call_args[2]['Subject'])
            self.assertIn(TEST_ROLE_NAME, call_args[2].as_string())
            self.assertIn('energy_usage', call_args[2].as_string())

            # request was removed
            self.assertFalse(access_requests)
            # table_role was extended to grant access to the energy_usage table
            perm_view = security_manager.find_permission_view_menu(
                'datasource_access', energy_usage_perm)
            TEST_ROLE = security_manager.find_role(TEST_ROLE_NAME)
            self.assertIn(perm_view, TEST_ROLE.permissions)

            # Case 3. Grant new role to the user to access the druid datasource.

            security_manager.add_role('druid_role')
            access_request3 = create_access_request(
                session, 'druid', 'druid_ds_1', 'druid_role', 'gamma')
            self.get_resp(GRANT_ROLE_REQUEST.format(
                'druid', access_request3.datasource_id, 'gamma', 'druid_role'))

            # user was granted druid_role
            user_roles = [r.name for r in security_manager.find_user('gamma').roles]
            self.assertIn('druid_role', user_roles)

            # Case 4. Extend the role to have access to the druid datasource

            access_request4 = create_access_request(
                session, 'druid', 'druid_ds_2', 'druid_role', 'gamma')
            druid_ds_2_perm = access_request4.datasource.perm

            self.client.get(EXTEND_ROLE_REQUEST.format(
                'druid', access_request4.datasource_id, 'gamma', 'druid_role'))
            # druid_role was extended to grant access to druid_ds_2
            druid_role = security_manager.find_role('druid_role')
            perm_view = security_manager.find_permission_view_menu(
                'datasource_access', druid_ds_2_perm)
            self.assertIn(perm_view, druid_role.permissions)

            # cleanup
            gamma_user = security_manager.find_user(username='******')
            gamma_user.roles.remove(security_manager.find_role('druid_role'))
            gamma_user.roles.remove(security_manager.find_role(TEST_ROLE_NAME))
            session.delete(security_manager.find_role('druid_role'))
            session.delete(security_manager.find_role(TEST_ROLE_NAME))
            session.commit()
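The create_access_request helper used by these approval tests is likewise defined outside this page. A rough sketch of what it does, assuming Superset's DatasourceAccessRequest model and covering only table datasources (the real helper also handles druid datasources and may differ):

def create_access_request(session, ds_type, ds_name, role_name, user_name):
    # Look up the datasource, attach its datasource_access permission to the
    # requested role, and record a pending access request for the user.
    ds = session.query(SqlaTable).filter_by(table_name=ds_name).first()
    perm_view = security_manager.find_permission_view_menu(
        'datasource_access', ds.perm)
    security_manager.add_permission_role(
        security_manager.find_role(role_name), perm_view)
    access_request = DatasourceAccessRequest(
        datasource_id=ds.id,
        datasource_type=ds_type,
        created_by_fk=security_manager.find_user(username=user_name).id,
    )
    session.add(access_request)
    session.commit()
    return access_request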
Example #52
0
    def test_request_access(self):
        if app.config.get('ENABLE_ACCESS_REQUEST'):
            session = db.session
            self.logout()
            self.login(username='******')
            gamma_user = security_manager.find_user(username='******')
            security_manager.add_role('dummy_role')
            gamma_user.roles.append(security_manager.find_role('dummy_role'))
            session.commit()

            ACCESS_REQUEST = (
                '/superset/request_access?'
                'datasource_type={}&'
                'datasource_id={}&'
                'action={}&')
            ROLE_GRANT_LINK = (
                '<a href="/superset/approve?datasource_type={}&datasource_id={}&'
                'created_by={}&role_to_grant={}">Grant {} Role</a>')

            # Request table access; no existing role has access to this table.

            table1 = session.query(SqlaTable).filter_by(
                table_name='random_time_series').first()
            table_1_id = table1.id

            # request access to the table
            resp = self.get_resp(
                ACCESS_REQUEST.format('table', table_1_id, 'go'))
            assert 'Access was requested' in resp
            access_request1 = self.get_access_requests('gamma', 'table', table_1_id)
            assert access_request1 is not None

            # Request access; roles exist that contain the table.
            # add table to the existing roles
            table3 = session.query(SqlaTable).filter_by(
                table_name='energy_usage').first()
            table_3_id = table3.id
            table3_perm = table3.perm

            security_manager.add_role('energy_usage_role')
            alpha_role = security_manager.find_role('Alpha')
            security_manager.add_permission_role(
                alpha_role,
                security_manager.find_permission_view_menu(
                    'datasource_access', table3_perm))
            security_manager.add_permission_role(
                security_manager.find_role('energy_usage_role'),
                security_manager.find_permission_view_menu(
                    'datasource_access', table3_perm))
            session.commit()

            self.get_resp(
                ACCESS_REQUEST.format('table', table_3_id, 'go'))
            access_request3 = self.get_access_requests('gamma', 'table', table_3_id)
            approve_link_3 = ROLE_GRANT_LINK.format(
                'table', table_3_id, 'gamma', 'energy_usage_role',
                'energy_usage_role')
            self.assertEqual(access_request3.roles_with_datasource,
                             '<ul><li>{}</li></ul>'.format(approve_link_3))

            # Request druid access; no existing role has access to this datasource.
            druid_ds_4 = session.query(DruidDatasource).filter_by(
                datasource_name='druid_ds_1').first()
            druid_ds_4_id = druid_ds_4.id

            # request access to the datasource
            self.get_resp(ACCESS_REQUEST.format('druid', druid_ds_4_id, 'go'))
            access_request4 = self.get_access_requests('gamma', 'druid', druid_ds_4_id)

            self.assertEqual(
                access_request4.roles_with_datasource,
                '<ul></ul>')

            # Case 5. Roles exist that contain the druid datasource.
            # add druid ds to the existing roles
            druid_ds_5 = session.query(DruidDatasource).filter_by(
                datasource_name='druid_ds_2').first()
            druid_ds_5_id = druid_ds_5.id
            druid_ds_5_perm = druid_ds_5.perm

            druid_ds_2_role = security_manager.add_role('druid_ds_2_role')
            admin_role = security_manager.find_role('Admin')
            security_manager.add_permission_role(
                admin_role,
                security_manager.find_permission_view_menu(
                    'datasource_access', druid_ds_5_perm))
            security_manager.add_permission_role(
                druid_ds_2_role,
                security_manager.find_permission_view_menu(
                    'datasource_access', druid_ds_5_perm))
            session.commit()

            self.get_resp(ACCESS_REQUEST.format('druid', druid_ds_5_id, 'go'))
            access_request5 = self.get_access_requests(
                'gamma', 'druid', druid_ds_5_id)
            approve_link_5 = ROLE_GRANT_LINK.format(
                'druid', druid_ds_5_id, 'gamma', 'druid_ds_2_role',
                'druid_ds_2_role')
            self.assertEqual(access_request5.roles_with_datasource,
                             '<ul><li>{}</li></ul>'.format(approve_link_5))

            # cleanup
            gamma_user = security_manager.find_user(username='******')
            gamma_user.roles.remove(security_manager.find_role('dummy_role'))
            session.commit()
Example #53
0
def load_birth_names():
    """Loading birth name dataset from a zip file in the repo"""
    with gzip.open(os.path.join(DATA_FOLDER, 'birth_names.json.gz')) as f:
        pdf = pd.read_json(f)
    pdf.ds = pd.to_datetime(pdf.ds, unit='ms')
    pdf.to_sql(
        'birth_names',
        db.engine,
        if_exists='replace',
        chunksize=500,
        dtype={
            'ds': DateTime,
            'gender': String(16),
            'state': String(10),
            'name': String(255),
        },
        index=False)
    print('Done loading table!')
    print('-' * 80)

    print('Creating table [birth_names] reference')
    obj = db.session.query(TBL).filter_by(table_name='birth_names').first()
    if not obj:
        obj = TBL(table_name='birth_names')
    obj.main_dttm_col = 'ds'
    obj.database = get_or_create_main_db()
    obj.filter_select_enabled = True

    if not any(col.column_name == 'num_california' for col in obj.columns):
        obj.columns.append(TableColumn(
            column_name='num_california',
            expression="CASE WHEN state = 'CA' THEN num ELSE 0 END",
        ))

    db.session.merge(obj)
    db.session.commit()
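    # fetch_metadata introspects the physical table and refreshes the
    # dataset's column and metric metadata.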
    obj.fetch_metadata()
    tbl = obj

    defaults = {
        'compare_lag': '10',
        'compare_suffix': 'o10Y',
        'limit': '25',
        'granularity_sqla': 'ds',
        'groupby': [],
        'metric': 'sum__num',
        'metrics': ['sum__num'],
        'row_limit': config.get('ROW_LIMIT'),
        'since': '100 years ago',
        'until': 'now',
        'viz_type': 'table',
        'where': '',
        'markup_type': 'markdown',
    }

    admin = security_manager.find_user('admin')

    print('Creating some slices')
    slices = [
        Slice(
            slice_name='Girls',
            viz_type='table',
            datasource_type='table',
            datasource_id=tbl.id,
            params=get_slice_json(
                defaults,
                groupby=['name'],
                filters=[{
                    'col': 'gender',
                    'op': 'in',
                    'val': ['girl'],
                }],
                row_limit=50,
                timeseries_limit_metric='sum__num')),
        Slice(
            slice_name='Boys',
            viz_type='table',
            datasource_type='table',
            datasource_id=tbl.id,
            params=get_slice_json(
                defaults,
                groupby=['name'],
                filters=[{
                    'col': 'gender',
                    'op': 'in',
                    'val': ['boy'],
                }],
                row_limit=50)),
        Slice(
            slice_name='Participants',
            viz_type='big_number',
            datasource_type='table',
            datasource_id=tbl.id,
            params=get_slice_json(
                defaults,
                viz_type='big_number', granularity_sqla='ds',
                compare_lag='5', compare_suffix='over 5Y')),
        Slice(
            slice_name='Genders',
            viz_type='pie',
            datasource_type='table',
            datasource_id=tbl.id,
            params=get_slice_json(
                defaults,
                viz_type='pie', groupby=['gender'])),
        Slice(
            slice_name='Genders by State',
            viz_type='dist_bar',
            datasource_type='table',
            datasource_id=tbl.id,
            params=get_slice_json(
                defaults,
                adhoc_filters=[
                    {
                        'clause': 'WHERE',
                        'expressionType': 'SIMPLE',
                        'filterOptionName': '2745eae5',
                        'comparator': ['other'],
                        'operator': 'not in',
                        'subject': 'state',
                    },
                ],
                viz_type='dist_bar',
                metrics=[
                    {
                        'expressionType': 'SIMPLE',
                        'column': {
                            'column_name': 'sum_boys',
                            'type': 'BIGINT(20)',
                        },
                        'aggregate': 'SUM',
                        'label': 'Boys',
                        'optionName': 'metric_11',
                    },
                    {
                        'expressionType': 'SIMPLE',
                        'column': {
                            'column_name': 'sum_girls',
                            'type': 'BIGINT(20)',
                        },
                        'aggregate': 'SUM',
                        'label': 'Girls',
                        'optionName': 'metric_12',
                    },
                ],
                groupby=['state'])),
        Slice(
            slice_name='Trends',
            viz_type='line',
            datasource_type='table',
            datasource_id=tbl.id,
            params=get_slice_json(
                defaults,
                viz_type='line', groupby=['name'],
                granularity_sqla='ds', rich_tooltip=True, show_legend=True)),
        Slice(
            slice_name='Average and Sum Trends',
            viz_type='dual_line',
            datasource_type='table',
            datasource_id=tbl.id,
            params=get_slice_json(
                defaults,
                viz_type='dual_line',
                metric={
                    'expressionType': 'SIMPLE',
                    'column': {
                        'column_name': 'num',
                        'type': 'BIGINT(20)',
                    },
                    'aggregate': 'AVG',
                    'label': 'AVG(num)',
                    'optionName': 'metric_vgops097wej_g8uff99zhk7',
                },
                metric_2='sum__num',
                granularity_sqla='ds')),
        Slice(
            slice_name='Title',
            viz_type='markup',
            datasource_type='table',
            datasource_id=tbl.id,
            params=get_slice_json(
                defaults,
                viz_type='markup', markup_type='html',
                code="""\
    <div style='text-align:center'>
        <h1>Birth Names Dashboard</h1>
        <p>
            The source dataset came from
            <a href='https://github.com/hadley/babynames' target='_blank'>[here]</a>
        </p>
        <img src='/static/assets/images/babytux.jpg'>
    </div>
    """)),
        Slice(
            slice_name='Name Cloud',
            viz_type='word_cloud',
            datasource_type='table',
            datasource_id=tbl.id,
            params=get_slice_json(
                defaults,
                viz_type='word_cloud', size_from='10',
                series='name', size_to='70', rotation='square',
                limit='100')),
        Slice(
            slice_name='Pivot Table',
            viz_type='pivot_table',
            datasource_type='table',
            datasource_id=tbl.id,
            params=get_slice_json(
                defaults,
                viz_type='pivot_table', metrics=['sum__num'],
                groupby=['name'], columns=['state'])),
        Slice(
            slice_name='Number of Girls',
            viz_type='big_number_total',
            datasource_type='table',
            datasource_id=tbl.id,
            params=get_slice_json(
                defaults,
                viz_type='big_number_total', granularity_sqla='ds',
                filters=[{
                    'col': 'gender',
                    'op': 'in',
                    'val': ['girl'],
                }],
                subheader='total female participants')),
        Slice(
            slice_name='Number of California Births',
            viz_type='big_number_total',
            datasource_type='table',
            datasource_id=tbl.id,
            params=get_slice_json(
                defaults,
                metric={
                    'expressionType': 'SIMPLE',
                    'column': {
                        'column_name': 'num_california',
                        'expression': "CASE WHEN state = 'CA' THEN num ELSE 0 END",
                    },
                    'aggregate': 'SUM',
                    'label': 'SUM(num_california)',
                },
                viz_type='big_number_total',
                granularity_sqla='ds')),
        Slice(
            slice_name='Top 10 California Names Timeseries',
            viz_type='line',
            datasource_type='table',
            datasource_id=tbl.id,
            params=get_slice_json(
                defaults,
                metrics=[{
                    'expressionType': 'SIMPLE',
                    'column': {
                        'column_name': 'num_california',
                        'expression': "CASE WHEN state = 'CA' THEN num ELSE 0 END",
                    },
                    'aggregate': 'SUM',
                    'label': 'SUM(num_california)',
                }],
                viz_type='line',
                granularity_sqla='ds',
                groupby=['name'],
                timeseries_limit_metric={
                    'expressionType': 'SIMPLE',
                    'column': {
                        'column_name': 'num_california',
                        'expression': "CASE WHEN state = 'CA' THEN num ELSE 0 END",
                    },
                    'aggregate': 'SUM',
                    'label': 'SUM(num_california)',
                },
                limit='10')),
        Slice(
            slice_name='Names Sorted by Num in California',
            viz_type='table',
            datasource_type='table',
            datasource_id=tbl.id,
            params=get_slice_json(
                defaults,
                groupby=['name'],
                row_limit=50,
                timeseries_limit_metric={
                    'expressionType': 'SIMPLE',
                    'column': {
                        'column_name': 'num_california',
                        'expression': "CASE WHEN state = 'CA' THEN num ELSE 0 END",
                    },
                    'aggregate': 'SUM',
                    'label': 'SUM(num_california)',
                })),
        Slice(
            slice_name='Num Births Trend',
            viz_type='line',
            datasource_type='table',
            datasource_id=tbl.id,
            params=get_slice_json(
                defaults,
                viz_type='line')),
        Slice(
            slice_name='Daily Totals',
            viz_type='table',
            datasource_type='table',
            datasource_id=tbl.id,
            created_by=admin,
            params=get_slice_json(
                defaults,
                groupby=['ds'],
                since='40 years ago',
                until='now',
                viz_type='table')),
    ]
    for slc in slices:
        merge_slice(slc)

    print('Creating a dashboard')
    dash = db.session.query(Dash).filter_by(dashboard_title='Births').first()

    if not dash:
        dash = Dash()
    js = textwrap.dedent("""\
{
    "CHART-0dd270f0": {
        "meta": {
            "chartId": 51,
            "width": 2,
            "height": 50
        },
        "type": "CHART",
        "id": "CHART-0dd270f0",
        "children": []
    },
    "CHART-a3c21bcc": {
        "meta": {
            "chartId": 52,
            "width": 2,
            "height": 50
        },
        "type": "CHART",
        "id": "CHART-a3c21bcc",
        "children": []
    },
    "CHART-976960a5": {
        "meta": {
            "chartId": 53,
            "width": 2,
            "height": 25
        },
        "type": "CHART",
        "id": "CHART-976960a5",
        "children": []
    },
    "CHART-58575537": {
        "meta": {
            "chartId": 54,
            "width": 2,
            "height": 25
        },
        "type": "CHART",
        "id": "CHART-58575537",
        "children": []
    },
    "CHART-e9cd8f0b": {
        "meta": {
            "chartId": 55,
            "width": 8,
            "height": 38
        },
        "type": "CHART",
        "id": "CHART-e9cd8f0b",
        "children": []
    },
    "CHART-e440d205": {
        "meta": {
            "chartId": 56,
            "width": 8,
            "height": 50
        },
        "type": "CHART",
        "id": "CHART-e440d205",
        "children": []
    },
    "CHART-59444e0b": {
        "meta": {
            "chartId": 57,
            "width": 3,
            "height": 38
        },
        "type": "CHART",
        "id": "CHART-59444e0b",
        "children": []
    },
    "CHART-e2cb4997": {
        "meta": {
            "chartId": 59,
            "width": 4,
            "height": 50
        },
        "type": "CHART",
        "id": "CHART-e2cb4997",
        "children": []
    },
    "CHART-e8774b49": {
        "meta": {
            "chartId": 60,
            "width": 12,
            "height": 50
        },
        "type": "CHART",
        "id": "CHART-e8774b49",
        "children": []
    },
    "CHART-985bfd1e": {
        "meta": {
            "chartId": 61,
            "width": 4,
            "height": 50
        },
        "type": "CHART",
        "id": "CHART-985bfd1e",
        "children": []
    },
    "CHART-17f13246": {
        "meta": {
            "chartId": 62,
            "width": 4,
            "height": 50
        },
        "type": "CHART",
        "id": "CHART-17f13246",
        "children": []
    },
    "CHART-729324f6": {
        "meta": {
            "chartId": 63,
            "width": 4,
            "height": 50
        },
        "type": "CHART",
        "id": "CHART-729324f6",
        "children": []
    },
    "COLUMN-25a865d6": {
        "meta": {
            "width": 4,
            "background": "BACKGROUND_TRANSPARENT"
        },
        "type": "COLUMN",
        "id": "COLUMN-25a865d6",
        "children": [
            "ROW-cc97c6ac",
            "CHART-e2cb4997"
        ]
    },
    "COLUMN-4557b6ba": {
        "meta": {
            "width": 8,
            "background": "BACKGROUND_TRANSPARENT"
        },
        "type": "COLUMN",
        "id": "COLUMN-4557b6ba",
        "children": [
            "ROW-d2e78e59",
            "CHART-e9cd8f0b"
        ]
    },
    "GRID_ID": {
        "type": "GRID",
        "id": "GRID_ID",
        "children": [
            "ROW-8515ace3",
            "ROW-1890385f",
            "ROW-f0b64094",
            "ROW-be9526b8"
        ]
    },
    "HEADER_ID": {
        "meta": {
            "text": "Births"
        },
        "type": "HEADER",
        "id": "HEADER_ID"
    },
    "MARKDOWN-00178c27": {
        "meta": {
            "width": 5,
            "code": "<div style=\\"text-align:center\\">\\n <h1>Birth Names Dashboard</h1>\\n <p>\\n The source dataset came from\\n <a href=\\"https://github.com/hadley/babynames\\" target=\\"_blank\\">[here]</a>\\n </p>\\n <img src=\\"/static/assets/images/babytux.jpg\\">\\n</div>\\n",
            "height": 38
        },
        "type": "MARKDOWN",
        "id": "MARKDOWN-00178c27",
        "children": []
    },
    "ROOT_ID": {
        "type": "ROOT",
        "id": "ROOT_ID",
        "children": [
            "GRID_ID"
        ]
    },
    "ROW-1890385f": {
        "meta": {
            "background": "BACKGROUND_TRANSPARENT"
        },
        "type": "ROW",
        "id": "ROW-1890385f",
        "children": [
            "CHART-e440d205",
            "CHART-0dd270f0",
            "CHART-a3c21bcc"
        ]
    },
    "ROW-8515ace3": {
        "meta": {
            "background": "BACKGROUND_TRANSPARENT"
        },
        "type": "ROW",
        "id": "ROW-8515ace3",
        "children": [
            "COLUMN-25a865d6",
            "COLUMN-4557b6ba"
        ]
    },
    "ROW-be9526b8": {
        "meta": {
            "background": "BACKGROUND_TRANSPARENT"
        },
        "type": "ROW",
        "id": "ROW-be9526b8",
        "children": [
            "CHART-985bfd1e",
            "CHART-17f13246",
            "CHART-729324f6"
        ]
    },
    "ROW-cc97c6ac": {
        "meta": {
            "background": "BACKGROUND_TRANSPARENT"
        },
        "type": "ROW",
        "id": "ROW-cc97c6ac",
        "children": [
            "CHART-976960a5",
            "CHART-58575537"
        ]
    },
    "ROW-d2e78e59": {
        "meta": {
            "background": "BACKGROUND_TRANSPARENT"
        },
        "type": "ROW",
        "id": "ROW-d2e78e59",
        "children": [
            "MARKDOWN-00178c27",
            "CHART-59444e0b"
        ]
    },
    "ROW-f0b64094": {
        "meta": {
            "background": "BACKGROUND_TRANSPARENT"
        },
        "type": "ROW",
        "id": "ROW-f0b64094",
        "children": [
            "CHART-e8774b49"
        ]
    },
    "DASHBOARD_VERSION_KEY": "v2"
}
        """)
    pos = json.loads(js)
    # dashboard v2 doesn't allow adding markup slices
    dash.slices = [slc for slc in slices if slc.viz_type != 'markup']
    update_slice_ids(pos, dash.slices)
    dash.dashboard_title = 'Births'
    dash.position_json = json.dumps(pos, indent=4)
    dash.slug = 'births'
    db.session.merge(dash)
    db.session.commit()