Example #1
    def _convert_to_model(self, datasource: DatasourceDict) -> BaseDatasource:

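        # resolve the raw datasource dict into its concrete ORM model via the DAO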
        return DatasourceDAO.get_datasource(
            session=db.session,
            datasource_type=DatasourceType(datasource["type"]),
            datasource_id=int(datasource["id"]),
        )
Example #2
    def test_query_cache_key_changes_when_datasource_is_updated(self):
        self.login(username="******")
        payload = get_query_context("birth_names")

        # construct baseline query_cache_key
        query_context = ChartDataQueryContextSchema().load(payload)
        query_object = query_context.queries[0]
        cache_key_original = query_context.query_cache_key(query_object)

        # make temporary change and revert it to refresh the changed_on property
        datasource = DatasourceDAO.get_datasource(
            session=db.session,
            datasource_type=DatasourceType(payload["datasource"]["type"]),
            datasource_id=payload["datasource"]["id"],
        )
        description_original = datasource.description
        datasource.description = "temporary description"
        db.session.commit()
        datasource.description = description_original
        db.session.commit()

        # create new QueryContext with unchanged attributes, extract new query_cache_key
        query_context = ChartDataQueryContextSchema().load(payload)
        query_object = query_context.queries[0]
        cache_key_new = query_context.query_cache_key(query_object)

        # the new cache_key should be different due to updated datasource
        self.assertNotEqual(cache_key_original, cache_key_new)
Example #3
    def test_query_cache_key_changes_when_metric_is_updated(self):
        self.login(username="******")
        payload = get_query_context("birth_names")

        # make temporary change and revert it to refresh the changed_on property
        datasource = DatasourceDAO.get_datasource(
            session=db.session,
            datasource_type=DatasourceType(payload["datasource"]["type"]),
            datasource_id=payload["datasource"]["id"],
        )

        datasource.metrics.append(SqlMetric(metric_name="foo", expression="select 1;"))
        db.session.commit()

        # construct baseline query_cache_key
        query_context = ChartDataQueryContextSchema().load(payload)
        query_object = query_context.queries[0]
        cache_key_original = query_context.query_cache_key(query_object)

        # wait a second since mysql records timestamps in second granularity
        time.sleep(1)

        datasource.metrics[0].expression = "select 2;"
        db.session.commit()

        # create new QueryContext with unchanged attributes, extract new query_cache_key
        query_context = ChartDataQueryContextSchema().load(payload)
        query_object = query_context.queries[0]
        cache_key_new = query_context.query_cache_key(query_object)

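        # remove the temporary metric so the test leaves no state behind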
        datasource.metrics = []
        db.session.commit()

        # the new cache_key should be different due to updated datasource
        self.assertNotEqual(cache_key_original, cache_key_new)
Example #4
def get_datasource_by_id(datasource_id: int,
                         datasource_type: str) -> BaseDatasource:
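    """Fetch a datasource by type and id, mapping a missing one to a validation error."""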
    try:
        return DatasourceDAO.get_datasource(db.session,
                                            DatasourceType(datasource_type),
                                            datasource_id)
    except DatasourceNotFound as ex:
        raise DatasourceNotFoundValidationError() from ex
Example #5
def test_get_datasource_query(session_with_data: Session) -> None:
    from superset.datasource.dao import DatasourceDAO
    from superset.models.sql_lab import Query

    result = DatasourceDAO.get_datasource(datasource_type=DatasourceType.QUERY,
                                          datasource_id=1,
                                          session=session_with_data)

    assert result.id == 1
    assert isinstance(result, Query)
Example #6
    def test_get_datasource_invalid_datasource_failed(self):
        from superset.datasource.dao import DatasourceDAO

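        # "druid" is no longer a valid DatasourceType, so the DAO should reject it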
        pytest.raises(
            DatasourceTypeNotSupportedError,
            lambda: DatasourceDAO.get_datasource(db.session, "druid", 9999999),
        )

        self.login(username="******")
        resp = self.get_json_resp("/datasource/get/druid/500000/", raise_on_error=False)
        self.assertEqual(resp.get("error"), "'druid' is not a valid DatasourceType")
Example #7
    def test_get_datasource_failed(self):
        from superset.datasource.dao import DatasourceDAO

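        # a valid type paired with a nonexistent id should raise DatasourceNotFound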
        pytest.raises(
            DatasourceNotFound,
            lambda: DatasourceDAO.get_datasource(db.session, "table", 9999999),
        )

        self.login(username="******")
        resp = self.get_json_resp("/datasource/get/table/500000/", raise_on_error=False)
        self.assertEqual(resp.get("error"), "Datasource does not exist")
Example #8
def test_get_datasource_saved_query(app_context: None,
                                    session_with_data: Session) -> None:
    from superset.datasource.dao import DatasourceDAO
    from superset.models.sql_lab import SavedQuery

    result = DatasourceDAO.get_datasource(
        datasource_type=DatasourceType.SAVEDQUERY,
        datasource_id=1,
        session=session_with_data,
    )

    assert result.id == 1
    assert isinstance(result, SavedQuery)
Example #9
    def external_metadata(self, datasource_type: str,
                          datasource_id: int) -> FlaskResponse:
        """Gets column info from the source system"""
        datasource = DatasourceDAO.get_datasource(
            db.session,
            DatasourceType(datasource_type),
            datasource_id,
        )
        try:
            external_metadata = datasource.external_metadata()
        except SupersetException as ex:
            return json_error_response(str(ex), status=400)
        return self.json_response(external_metadata)
Example #10
def test_get_datasource_sqlatable(session_with_data: Session) -> None:
    from superset.connectors.sqla.models import SqlaTable
    from superset.datasource.dao import DatasourceDAO

    result = DatasourceDAO.get_datasource(
        datasource_type=DatasourceType.TABLE,
        datasource_id=1,
        session=session_with_data,
    )

    assert 1 == result.id
    assert "my_sqla_table" == result.table_name
    assert isinstance(result, SqlaTable)
Example #11
def test_get_datasource_w_str_param(session_with_data: Session) -> None:
    from superset.connectors.sqla.models import SqlaTable
    from superset.datasets.models import Dataset
    from superset.datasource.dao import DatasourceDAO
    from superset.tables.models import Table

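    # plain strings are accepted in place of DatasourceType members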
    assert isinstance(
        DatasourceDAO.get_datasource(
            datasource_type="table",
            datasource_id=1,
            session=session_with_data,
        ),
        SqlaTable,
    )

    assert isinstance(
        DatasourceDAO.get_datasource(
            datasource_type="sl_table",
            datasource_id=1,
            session=session_with_data,
        ),
        Table,
    )
Example #12
def test_get_datasource_sl_dataset(session_with_data: Session) -> None:
    from superset.datasets.models import Dataset
    from superset.datasource.dao import DatasourceDAO

    # todo(hugh): This will break once we remove the dual write
    # update the datasource_id=1 and this will pass again
    result = DatasourceDAO.get_datasource(
        datasource_type=DatasourceType.DATASET,
        datasource_id=2,
        session=session_with_data,
    )

    assert result.id == 2
    assert isinstance(result, Dataset)
Example #13
def test_get_datasource_sl_table(app_context: None,
                                 session_with_data: Session) -> None:
    from superset.datasource.dao import DatasourceDAO
    from superset.tables.models import Table

    # todo(hugh): This will break once we remove the dual write
    # update the datasource_id=1 and this will pass again
    result = DatasourceDAO.get_datasource(
        datasource_type=DatasourceType.SLTABLE,
        datasource_id=2,
        session=session_with_data,
    )

    assert result.id == 2
    assert isinstance(result, Table)
Example #14
def get_viz(
    form_data: FormData,
    datasource_type: str,
    datasource_id: int,
    force: bool = False,
    force_cached: bool = False,
) -> BaseViz:
    viz_type = form_data.get("viz_type", "table")
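    # look up the ORM datasource that backs the requested visualization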
    datasource = DatasourceDAO.get_datasource(
        db.session,
        DatasourceType(datasource_type),
        datasource_id,
    )
    viz_obj = viz.viz_types[viz_type](
        datasource, form_data=form_data, force=force, force_cached=force_cached
    )
    return viz_obj
Example #15
    def save(self) -> FlaskResponse:
        data = request.form.get("data")
        if not isinstance(data, str):
            return json_error_response(_("Request missing data field."),
                                       status=500)

        datasource_dict = json.loads(data)
        datasource_id = datasource_dict.get("id")
        datasource_type = datasource_dict.get("type")
        database_id = datasource_dict["database"].get("id")
        orm_datasource = DatasourceDAO.get_datasource(
            db.session, DatasourceType(datasource_type), datasource_id)
        orm_datasource.database_id = database_id

        if "owners" in datasource_dict and orm_datasource.owner_class is not None:
            # Check ownership
            try:
                security_manager.raise_for_ownership(orm_datasource)
            except SupersetSecurityException as ex:
                raise DatasetForbiddenError() from ex

        datasource_dict["owners"] = populate_owners(datasource_dict["owners"],
                                                    default_to_user=False)

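        # reject payloads that declare the same column name more than once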
        duplicates = [
            name for name, count in Counter(
                [col["column_name"]
                 for col in datasource_dict["columns"]]).items() if count > 1
        ]
        if duplicates:
            return json_error_response(
                _(
                    "Duplicate column name(s): %(columns)s",
                    columns=",".join(duplicates),
                ),
                status=409,
            )
        orm_datasource.update_from_object(datasource_dict)
        data = orm_datasource.data
        db.session.commit()

        return self.json_response(sanitize_datasource_data(data))
Example #16
def get_samples(  # pylint: disable=too-many-arguments,too-many-locals
    datasource_type: str,
    datasource_id: int,
    force: bool = False,
    page: int = 1,
    per_page: int = 1000,
    payload: Optional[SamplesPayloadSchema] = None,
) -> Dict[str, Any]:
    datasource = DatasourceDAO.get_datasource(
        session=db.session,
        datasource_type=datasource_type,
        datasource_id=datasource_id,
    )

    limit_clause = get_limit_clause(page, per_page)

    # todo(yongjie): Constructing count(*) and samples in the same query_context,
    #  then remove query_type==SAMPLES
    # constructing samples query
    samples_instance = QueryContextFactory().create(
        datasource={
            "type": datasource.type,
            "id": datasource.id,
        },
        queries=[{
            **payload,
            **limit_clause
        } if payload else limit_clause],
        result_type=ChartDataResultType.SAMPLES,
        force=force,
    )

    # constructing count(*) query
    count_star_metric = {
        "metrics": [{
            "expressionType": "SQL",
            "sqlExpression": "COUNT(*)",
            "label": "COUNT(*)",
        }]
    }
    count_star_instance = QueryContextFactory().create(
        datasource={
            "type": datasource.type,
            "id": datasource.id,
        },
        queries=[{
            **payload,
            **count_star_metric
        } if payload else count_star_metric],
        result_type=ChartDataResultType.FULL,
        force=force,
    )
    samples_results = samples_instance.get_payload()
    count_star_results = count_star_instance.get_payload()

    try:
        sample_data = samples_results["queries"][0]
        count_star_data = count_star_results["queries"][0]
        failed_status = (sample_data.get("status") == QueryStatus.FAILED or
                         count_star_data.get("status") == QueryStatus.FAILED)
        error_msg = sample_data.get("error") or count_star_data.get("error")
        if failed_status and error_msg:
            cache_key = sample_data.get("cache_key")
            QueryCacheManager.delete(cache_key, region=CacheRegion.DATA)
            raise DatasetSamplesFailedError(error_msg)

        sample_data["page"] = page
        sample_data["per_page"] = per_page
        sample_data["total_count"] = count_star_data["data"][0]["COUNT(*)"]
        return sample_data
    except (IndexError, KeyError) as exc:
        raise DatasetSamplesFailedError from exc
Example #17
def create_query_object_factory() -> QueryObjectFactory:
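    # wire the factory with the app config, the datasource DAO, and the scoped session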
    return QueryObjectFactory(config, DatasourceDAO(), db.session)
Example #18
    def run(self) -> Optional[Dict[str, Any]]:
        initial_form_data = {}

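        # seed the form data from a permalink or a cached form-data key, if one was given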
        if self._permalink_key is not None:
            command = GetExplorePermalinkCommand(self._permalink_key)
            permalink_value = command.run()
            if not permalink_value:
                raise ExplorePermalinkGetFailedError()
            state = permalink_value["state"]
            initial_form_data = state["formData"]
            url_params = state.get("urlParams")
            if url_params:
                initial_form_data["url_params"] = dict(url_params)
        elif self._form_data_key:
            parameters = FormDataCommandParameters(key=self._form_data_key)
            value = GetFormDataCommand(parameters).run()
            initial_form_data = json.loads(value) if value else {}

        message = None

        if not initial_form_data:
            if self._slice_id:
                initial_form_data["slice_id"] = self._slice_id
                if self._form_data_key:
                    message = _(
                        "Form data not found in cache, reverting to chart metadata."
                    )
            elif self._dataset_id:
                initial_form_data[
                    "datasource"
                ] = f"{self._dataset_id}__{self._dataset_type}"
                if self._form_data_key:
                    message = _(
                        "Form data not found in cache, reverting to dataset metadata."
                    )

        form_data, slc = get_form_data(
            use_slice_data=True, initial_form_data=initial_form_data
        )
        try:
            self._dataset_id, self._dataset_type = get_datasource_info(
                self._dataset_id, self._dataset_type, form_data
            )
        except SupersetException:
            self._dataset_id = None
            # fall back to the table type for an unknown datasource
            self._dataset_type = SqlaTable.type

        dataset: Optional[BaseDatasource] = None
        if self._dataset_id is not None:
            try:
                dataset = DatasourceDAO.get_datasource(
                    db.session, cast(str, self._dataset_type), self._dataset_id
                )
            except DatasetNotFoundError:
                pass
        dataset_name = dataset.name if dataset else _("[Missing Dataset]")

        if dataset:
            if app.config["ENABLE_ACCESS_REQUEST"] and (
                not security_manager.can_access_datasource(dataset)
            ):
                message = __(security_manager.get_datasource_access_error_msg(dataset))
                raise DatasetAccessDeniedError(
                    message=message,
                    dataset_type=self._dataset_type,
                    dataset_id=self._dataset_id,
                )

        viz_type = form_data.get("viz_type")
        if not viz_type and dataset and dataset.default_endpoint:
            raise WrongEndpointError(redirect=dataset.default_endpoint)

        form_data["datasource"] = (
            str(self._dataset_id) + "__" + cast(str, self._dataset_type)
        )

        # On explore, merge legacy and extra filters into the form data
        utils.convert_legacy_filters_into_adhoc(form_data)
        utils.merge_extra_filters(form_data)

        dummy_dataset_data: Dict[str, Any] = {
            "type": self._dataset_type,
            "name": dataset_name,
            "columns": [],
            "metrics": [],
            "database": {"id": 0, "backend": ""},
        }
        try:
            dataset_data = dataset.data if dataset else dummy_dataset_data
        except (SupersetException, SQLAlchemyError):
            dataset_data = dummy_dataset_data

        return {
            "dataset": sanitize_datasource_data(dataset_data),
            "form_data": form_data,
            "slice": slc.data if slc else None,
            "message": message,
        }
Example #19
    def export_dashboards(  # pylint: disable=too-many-locals
            cls, dashboard_ids: List[int]) -> str:
        copied_dashboards = []
        datasource_ids = set()
        for dashboard_id in dashboard_ids:
            # make sure that dashboard_id is an integer
            dashboard_id = int(dashboard_id)
            dashboard = (db.session.query(Dashboard).options(
                subqueryload(
                    Dashboard.slices)).filter_by(id=dashboard_id).first())
            # remove ids and relations (like owners, created by, slices, ...)
            copied_dashboard = dashboard.copy()
            for slc in dashboard.slices:
                datasource_ids.add((slc.datasource_id, slc.datasource_type))
                copied_slc = slc.copy()
                # save original id into json
                # we need it to update dashboard's json metadata on import
                copied_slc.id = slc.id
                # add extra params for the import
                copied_slc.alter_params(
                    remote_id=slc.id,
                    datasource_name=slc.datasource.datasource_name,
                    schema=slc.datasource.schema,
                    database_name=slc.datasource.database.name,
                )
                # set slices without creating ORM relations
                slices = copied_dashboard.__dict__.setdefault("slices", [])
                slices.append(copied_slc)

            json_metadata = json.loads(dashboard.json_metadata)
            native_filter_configuration: List[Dict[
                str, Any]] = json_metadata.get("native_filter_configuration",
                                               [])
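            # also export any datasets that native filters point at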
            for native_filter in native_filter_configuration:
                session = db.session()
                for target in native_filter.get("targets", []):
                    id_ = target.get("datasetId")
                    if id_ is None:
                        continue
                    datasource = DatasourceDAO.get_datasource(
                        session, utils.DatasourceType.TABLE, id_)
                    datasource_ids.add((datasource.id, datasource.type))

            copied_dashboard.alter_params(remote_id=dashboard_id)
            copied_dashboards.append(copied_dashboard)

        eager_datasources = []
        for datasource_id, _ in datasource_ids:
            eager_datasource = SqlaTable.get_eager_sqlatable_datasource(
                db.session, datasource_id)
            copied_datasource = eager_datasource.copy()
            copied_datasource.alter_params(
                remote_id=eager_datasource.id,
                database_name=eager_datasource.database.name,
            )
            datasource_class = copied_datasource.__class__
            for field_name in datasource_class.export_children:
                field_val = getattr(eager_datasource, field_name).copy()
                # set children without creating ORM relations
                copied_datasource.__dict__[field_name] = field_val
            eager_datasources.append(copied_datasource)

        return json.dumps(
            {
                "dashboards": copied_dashboards,
                "datasources": eager_datasources
            },
            cls=utils.DashboardEncoder,
            indent=4,
        )
Example #20
    def get(self, datasource_type: str, datasource_id: int) -> FlaskResponse:
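        """Return the sanitized data payload for a single datasource."""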
        datasource = DatasourceDAO.get_datasource(
            db.session, DatasourceType(datasource_type), datasource_id)
        return self.json_response(sanitize_datasource_data(datasource.data))