Example #1
    def test_get_datasource(self):
        self.login(username="******")
        tbl = self.get_table_by_name("birth_names")
        self.datasource = ConnectorRegistry.get_datasource("table", tbl.id, db.session)

        for key in self.datasource.export_fields:
            self.original_attrs[key] = getattr(self.datasource, key)

        datasource_post["id"] = tbl.id
        data = dict(data=json.dumps(datasource_post))
        self.get_json_resp("/datasource/save/", data)
        url = f"/datasource/get/{tbl.type}/{tbl.id}/"
        resp = self.get_json_resp(url)
        self.assertEqual(resp.get("type"), "table")
        col_names = {o.get("column_name") for o in resp["columns"]}
        self.assertEqual(
            col_names,
            {
                "num_boys",
                "num",
                "gender",
                "name",
                "ds",
                "state",
                "num_girls",
                "num_california",
            },
        )
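The datasource_post fixture used by this test is defined elsewhere in the test module. A minimal sketch of its shape, inferred from how these examples read it (the real fixture carries more fields, and the ids below are placeholders):

datasource_post = {
    "id": None,  # each test fills this in with the target table id
    "type": "table",
    "database": {"id": 1},  # assumed id of the example database
    "columns": [
        {
            "id": 1,
            "column_name": "ds",
            "filterable": True,
            "groupby": True,
            "expression": "",
        },
    ],
    "metrics": [
        {"id": 1, "metric_name": "count", "expression": "COUNT(*)"},
    ],
}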
Example #2
    def insert_chart(
        self,
        slice_name: str,
        owners: List[int],
        datasource_id: int,
        created_by=None,
        datasource_type: str = "table",
        description: Optional[str] = None,
        viz_type: Optional[str] = None,
        params: Optional[str] = None,
        cache_timeout: Optional[int] = None,
    ) -> Slice:
        # resolve owner ids into user objects
        obj_owners = []
        for owner in owners:
            user = db.session.query(security_manager.user_model).get(owner)
            obj_owners.append(user)
        datasource = ConnectorRegistry.get_datasource(
            datasource_type, datasource_id, db.session
        )
        # named slc to avoid shadowing the builtin slice
        slc = Slice(
            cache_timeout=cache_timeout,
            created_by=created_by,
            datasource_id=datasource.id,
            datasource_name=datasource.name,
            datasource_type=datasource.type,
            description=description,
            owners=obj_owners,
            params=params,
            slice_name=slice_name,
            viz_type=viz_type,
        )
        db.session.add(slc)
        db.session.commit()
        return slc
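A hypothetical call site for this helper, from inside a test method (the owner id and params value are assumptions):

        chart = self.insert_chart(
            slice_name="Example chart",
            owners=[1],  # assumed admin user id
            datasource_id=tbl.id,  # a table id fetched as in the other examples
            viz_type="table",
            params='{"granularity_sqla": "ds"}',
        )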
Example #3
    def test_save_duplicate_key(self):
        self.login(username="******")
        tbl_id = self.get_table_by_name("birth_names").id
        self.datasource = ConnectorRegistry.get_datasource("table", tbl_id, db.session)

        for key in self.datasource.export_fields:
            self.original_attrs[key] = getattr(self.datasource, key)

        datasource_post_copy = deepcopy(datasource_post)
        datasource_post_copy["id"] = tbl_id
        datasource_post_copy["columns"].extend(
            [
                {
                    "column_name": "<new column>",
                    "filterable": True,
                    "groupby": True,
                    "expression": "<enter SQL expression here>",
                    "id": "somerandomid",
                },
                {
                    "column_name": "<new column>",
                    "filterable": True,
                    "groupby": True,
                    "expression": "<enter SQL expression here>",
                    "id": "somerandomid2",
                },
            ]
        )
        data = dict(data=json.dumps(datasource_post_copy))
        resp = self.get_json_resp("/datasource/save/", data, raise_on_error=False)
        self.assertIn("Duplicate column name(s): <new column>", resp["error"])
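The server-side check behind this error message is not shown in the excerpt. A minimal sketch of how duplicate column names can be detected with collections.Counter (illustrative only, not necessarily Superset's implementation):

from collections import Counter

def find_duplicate_column_names(columns):
    # count each column_name and report the ones that appear more than once
    counts = Counter(col["column_name"] for col in columns)
    return sorted(name for name, count in counts.items() if count > 1)

# find_duplicate_column_names(datasource_post_copy["columns"]) -> ["<new column>"]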
Example #4
    def export_dashboards(  # pylint: disable=too-many-locals
        cls, dashboard_ids: List[int]
    ) -> str:
        copied_dashboards = []
        datasource_ids = set()
        for dashboard_id in dashboard_ids:
            # make sure that dashboard_id is an integer
            dashboard_id = int(dashboard_id)
            dashboard = (
                db.session.query(Dashboard)
                .options(subqueryload(Dashboard.slices))
                .filter_by(id=dashboard_id)
                .first()
            )
            # remove ids and relations (like owners, created by, slices, ...)
            copied_dashboard = dashboard.copy()
            for slc in dashboard.slices:
                datasource_ids.add((slc.datasource_id, slc.datasource_type))
                copied_slc = slc.copy()
                # save original id into json
                # we need it to update dashboard's json metadata on import
                copied_slc.id = slc.id
                # add extra params for the import
                copied_slc.alter_params(
                    remote_id=slc.id,
                    datasource_name=slc.datasource.datasource_name,
                    schema=slc.datasource.schema,
                    database_name=slc.datasource.database.name,
                )
                # set slices without creating ORM relations
                slices = copied_dashboard.__dict__.setdefault("slices", [])
                slices.append(copied_slc)
            copied_dashboard.alter_params(remote_id=dashboard_id)
            copied_dashboards.append(copied_dashboard)

        eager_datasources = []
        for datasource_id, datasource_type in datasource_ids:
            eager_datasource = ConnectorRegistry.get_eager_datasource(
                db.session, datasource_type, datasource_id
            )
            copied_datasource = eager_datasource.copy()
            copied_datasource.alter_params(
                remote_id=eager_datasource.id,
                database_name=eager_datasource.database.name,
            )
            datasource_class = copied_datasource.__class__
            for field_name in datasource_class.export_children:
                field_val = getattr(eager_datasource, field_name).copy()
                # set children without creating ORM relations
                copied_datasource.__dict__[field_name] = field_val
            eager_datasources.append(copied_datasource)

        return json.dumps(
            {"dashboards": copied_dashboards, "datasources": eager_datasources},
            cls=utils.DashboardEncoder,
            indent=4,
        )
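A hedged usage sketch, assuming this classmethod lives on the Dashboard model as it does in Superset:

# Export two dashboards by id and write the JSON payload to disk.
exported = Dashboard.export_dashboards([1, 2])
with open("dashboards_export.json", "w") as fp:
    fp.write(exported)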
Example #5
    def test_get_datasource_with_health_check(self):
        def my_check(datasource):
            return "Warning message!"

        app.config["DATASET_HEALTH_CHECK"] = my_check
        self.login(username="******")
        tbl = self.get_table(name="birth_names")
        datasource = ConnectorRegistry.get_datasource("table", tbl.id, db.session)
        assert datasource.health_check_message == "Warning message!"
        app.config["DATASET_HEALTH_CHECK"] = None
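The same hook can be installed from superset_config.py rather than patched in a test. A sketch, assuming a check that warns about datasets lacking a description:

# superset_config.py (sketch)
def dataset_health_check(datasource):
    # return a warning string to surface in the UI, or None if healthy
    if not datasource.description:
        return "This dataset has no description."
    return None

DATASET_HEALTH_CHECK = dataset_health_check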
Example #6
    def test_get_datasource_failed(self):
        pytest.raises(
            DatasetNotFoundError,
            lambda: ConnectorRegistry.get_datasource("table", 9999999, db.session),
        )

        self.login(username="******")
        resp = self.get_json_resp("/datasource/get/druid/500000/", raise_on_error=False)
        self.assertEqual(resp.get("error"), "Dataset does not exist")

        resp = self.get_json_resp(
            "/datasource/get/invalid-datasource-type/500000/", raise_on_error=False
        )
        self.assertEqual(resp.get("error"), "Dataset does not exist")
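The lambda form above is equivalent to the more common context-manager form of pytest.raises:

with pytest.raises(DatasetNotFoundError):
    ConnectorRegistry.get_datasource("table", 9999999, db.session)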
Example #7
    def test_external_metadata_error_return_400(self, mock_get_datasource):
        self.login(username="******")
        tbl = self.get_table(name="birth_names")
        url = f"/datasource/external_metadata/table/{tbl.id}/"

        mock_get_datasource.side_effect = SupersetGenericDBErrorException("oops")

        pytest.raises(
            SupersetGenericDBErrorException,
            lambda: ConnectorRegistry.get_datasource(
                "table", tbl.id, db.session
            ).external_metadata(),
        )

        resp = self.client.get(url)
        assert resp.status_code == 400
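The mock_get_datasource parameter implies the excerpt was decorated with unittest.mock.patch; the dotted patch target below is an assumption, since the decorator is cut off in this excerpt:

from unittest import mock

@mock.patch(
    "superset.connectors.connector_registry.ConnectorRegistry.get_datasource"
)
def test_external_metadata_error_return_400(self, mock_get_datasource):
    ...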
Example #8
def _cleanup(dash_id: int, slices_ids: List[int]) -> None:
    table_id = db.session.query(SqlaTable).filter_by(table_name="birth_names").one().id
    datasource = ConnectorRegistry.get_datasource("table", table_id, db.session)
    # materialize the lists before deleting, since deleting mutates the relationships
    columns = list(datasource.columns)
    metrics = list(datasource.metrics)

    engine = get_example_database().get_sqla_engine()
    engine.execute("DROP TABLE IF EXISTS birth_names")
    for column in columns:
        db.session.delete(column)
    for metric in metrics:
        db.session.delete(metric)

    dash = db.session.query(Dashboard).filter_by(id=dash_id).first()

    db.session.delete(dash)
    for slice_id in slices_ids:
        db.session.query(Slice).filter_by(id=slice_id).delete()
    db.session.commit()
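A hypothetical call, assuming the dashboard and slice ids were captured when the fixtures were created earlier in the test:

_cleanup(dash.id, [slc.id for slc in slices])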
Example #9
    def import_obj(
        cls,
        slc_to_import: "Slice",
        slc_to_override: Optional["Slice"],
        import_time: Optional[int] = None,
    ) -> int:
        """Inserts or overrides slc in the database.

        remote_id and import_time fields in params_dict are set to track the
        slice origin and ensure correct overrides for multiple imports.
        Slice.perm is used to find the datasources and connect them.

        :param Slice slc_to_import: Slice object to import
        :param Slice slc_to_override: Slice to replace, id matches remote_id
        :returns: The resulting id for the imported slice
        :rtype: int
        """
        session = db.session
        make_transient(slc_to_import)
        slc_to_import.dashboards = []
        slc_to_import.alter_params(
            remote_id=slc_to_import.id, import_time=import_time
        )

        slc_to_import = slc_to_import.copy()
        slc_to_import.reset_ownership()
        params = slc_to_import.params_dict
        datasource = ConnectorRegistry.get_datasource_by_name(
            session,
            slc_to_import.datasource_type,
            params["datasource_name"],
            params["schema"],
            params["database_name"],
        )
        slc_to_import.datasource_id = datasource.id  # type: ignore
        if slc_to_override:
            slc_to_override.override(slc_to_import)
            session.flush()
            return slc_to_override.id
        session.add(slc_to_import)
        logger.info("Final slice: %s", str(slc_to_import.to_json()))
        session.flush()
        return slc_to_import.id
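A hedged usage sketch, assuming import_obj is a classmethod on the Slice model; because remote_id is tracked in params, re-importing the same payload overrides the earlier row instead of duplicating it:

import time

# First import: nothing matches the remote_id, so a new slice row is created.
new_id = Slice.import_obj(slc_to_import, None, import_time=int(time.time()))

# Later import of the same payload: pass the existing slice to override it.
existing = db.session.query(Slice).get(new_id)
Slice.import_obj(slc_to_import, existing, import_time=int(time.time()))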
Example #10
    def test_get_datasource_with_health_check(self):
        def my_check(datasource):
            return "Warning message!"

        app.config["DATASET_HEALTH_CHECK"] = my_check
        my_check.version = 0.1

        self.login(username="******")
        tbl = self.get_table_by_name("birth_names")
        self.datasource = ConnectorRegistry.get_datasource("table", tbl.id, db.session)

        for key in self.datasource.export_fields:
            self.original_attrs[key] = getattr(self.datasource, key)

        url = f"/datasource/get/{tbl.type}/{tbl.id}/"
        tbl.health_check(commit=True, force=True)
        resp = self.get_json_resp(url)
        self.assertEqual(resp["health_check_message"], "Warning message!")

        del app.config["DATASET_HEALTH_CHECK"]
Example #11
    def test_save(self):
        self.login(username="******")
        tbl_id = self.get_table_by_name("birth_names").id

        self.datasource = ConnectorRegistry.get_datasource("table", tbl_id, db.session)

        for key in self.datasource.export_fields:
            self.original_attrs[key] = getattr(self.datasource, key)

        datasource_post["id"] = tbl_id
        data = dict(data=json.dumps(datasource_post))
        resp = self.get_json_resp("/datasource/save/", data)
        for k in datasource_post:
            if k == "columns":
                self.compare_lists(datasource_post[k], resp[k], "column_name")
            elif k == "metrics":
                self.compare_lists(datasource_post[k], resp[k], "metric_name")
            elif k == "database":
                self.assertEqual(resp[k]["id"], datasource_post[k]["id"])
            else:
                self.assertEqual(resp[k], datasource_post[k])
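The compare_lists helper comes from the test base class and is not shown here; a minimal sketch of what it plausibly does, comparing two lists of dicts by a key field while ignoring order:

    def compare_lists(self, expected, actual, key):
        self.assertEqual(
            {o.get(key) for o in expected},
            {o.get(key) for o in actual},
        )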
Example #12
def _cleanup(dash_id: int, slices_ids: List[int]) -> None:
    schema = get_example_default_schema()

    table_id = (
        db.session.query(SqlaTable)
        .filter_by(table_name="birth_names", schema=schema)
        .one()
        .id
    )
    datasource = ConnectorRegistry.get_datasource("table", table_id, db.session)
    columns = list(datasource.columns)
    metrics = list(datasource.metrics)

    for column in columns:
        db.session.delete(column)
    for metric in metrics:
        db.session.delete(metric)

    dash = db.session.query(Dashboard).filter_by(id=dash_id).first()

    db.session.delete(dash)
    for slice_id in slices_ids:
        db.session.query(Slice).filter_by(id=slice_id).delete()
    db.session.commit()
Example #13
    def test_change_database(self):
        self.login(username="******")
        tbl = self.get_table_by_name("birth_names")
        tbl_id = tbl.id
        db_id = tbl.database_id
        datasource_post["id"] = tbl_id

        self.datasource = ConnectorRegistry.get_datasource("table", tbl_id, db.session)

        for key in self.datasource.export_fields:
            self.original_attrs[key] = getattr(self.datasource, key)

        new_db = self.create_fake_db()

        datasource_post["database"]["id"] = new_db.id
        resp = self.save_datasource_from_dict(datasource_post)
        self.assertEqual(resp["database"]["id"], new_db.id)

        datasource_post["database"]["id"] = db_id
        resp = self.save_datasource_from_dict(datasource_post)
        self.assertEqual(resp["database"]["id"], db_id)

        self.delete_fake_db()
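save_datasource_from_dict is another helper not shown in these excerpts; judging from the save calls in Examples #1 and #11, it is plausibly a thin wrapper such as:

    def save_datasource_from_dict(self, datasource_post):
        # POST the payload to the save endpoint and return the parsed JSON
        data = dict(data=json.dumps(datasource_post))
        return self.get_json_resp("/datasource/save/", data)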