Example #1
    def _import_bundle(self, session: Session) -> None:
        # first import databases
        database_ids: Dict[str, int] = {}
        for file_name, config in self._configs.items():
            if file_name.startswith("databases/"):
                database = import_database(session, config, overwrite=True)
                database_ids[str(database.uuid)] = database.id

        # import related datasets
        for file_name, config in self._configs.items():
            if (file_name.startswith("datasets/")
                    and config["database_uuid"] in database_ids):
                config["database_id"] = database_ids[config["database_uuid"]]
                # overwrite=False prevents deleting any non-imported columns/metrics
                import_dataset(session, config, overwrite=False)
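
Note: the two passes above rely on the shape of the bundle: a mapping from export file paths to parsed YAML configs, where each dataset references its parent database by UUID. A minimal standalone sketch of that shape and of the UUID-to-ID remapping (hypothetical values, not a real Superset export):

    # Hypothetical bundle, keyed by export file path.
    configs = {
        "databases/examples.yaml": {
            "uuid": "11111111-1111-1111-1111-111111111111",
            "database_name": "examples",
        },
        "datasets/examples/my_table.yaml": {
            "uuid": "22222222-2222-2222-2222-222222222222",
            "database_uuid": "11111111-1111-1111-1111-111111111111",
            "table_name": "my_table",
        },
    }

    # First pass (stand-in for import_database): record UUID -> new primary key.
    database_ids = {
        config["uuid"]: new_id
        for new_id, config in enumerate(
            (c for f, c in configs.items() if f.startswith("databases/")), start=1
        )
    }

    # Second pass: rewrite each dataset's parent reference to the new ID.
    for file_name, config in configs.items():
        if file_name.startswith("datasets/") and config["database_uuid"] in database_ids:
            config["database_id"] = database_ids[config["database_uuid"]]

    assert configs["datasets/examples/my_table.yaml"]["database_id"] == 1
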
Example #2
def test_import_dataset_managed_externally(app_context: None,
                                           session: Session) -> None:
    """
    Test importing a dataset that is managed externally.
    """
    import copy
    import uuid

    from superset.connectors.sqla.models import SqlaTable
    from superset.datasets.commands.importers.v1.utils import import_dataset
    from superset.models.core import Database
    from tests.integration_tests.fixtures.importexport import dataset_config

    engine = session.get_bind()
    SqlaTable.metadata.create_all(engine)  # pylint: disable=no-member

    database = Database(database_name="my_database",
                        sqlalchemy_uri="sqlite://")
    session.add(database)
    session.flush()

    dataset_uuid = uuid.uuid4()
    config = copy.deepcopy(dataset_config)
    config["is_managed_externally"] = True
    config["external_url"] = "https://example.org/my_table"
    config["database_id"] = database.id

    sqla_table = import_dataset(session, config)
    assert sqla_table.is_managed_externally is True
    assert sqla_table.external_url == "https://example.org/my_table"
Example #3
    def _import(session: Session, configs: Dict[str, Any]) -> None:
        # import databases first
        database_ids: Dict[str, int] = {}
        for file_name, config in configs.items():
            if file_name.startswith("databases/"):
                database = import_database(session, config, overwrite=True)
                database_ids[str(database.uuid)] = database.id

        # import saved queries
        for file_name, config in configs.items():
            if file_name.startswith("queries/"):
                config["db_id"] = database_ids[config["database_uuid"]]
                import_saved_query(session, config, overwrite=True)

        # import datasets
        dataset_info: Dict[str, Dict[str, Any]] = {}
        for file_name, config in configs.items():
            if file_name.startswith("datasets/"):
                config["database_id"] = database_ids[config["database_uuid"]]
                dataset = import_dataset(session, config, overwrite=True)
                dataset_info[str(dataset.uuid)] = {
                    "datasource_id": dataset.id,
                    "datasource_type": dataset.datasource_type,
                    "datasource_name": dataset.table_name,
                }

        # import charts
        chart_ids: Dict[str, int] = {}
        for file_name, config in configs.items():
            if file_name.startswith("charts/"):
                config.update(dataset_info[config["dataset_uuid"]])
                chart = import_chart(session, config, overwrite=True)
                chart_ids[str(chart.uuid)] = chart.id

        # store the existing relationship between dashboards and charts
        existing_relationships = session.execute(
            select([dashboard_slices.c.dashboard_id, dashboard_slices.c.slice_id])
        ).fetchall()

        # import dashboards
        dashboard_chart_ids: List[Tuple[int, int]] = []
        for file_name, config in configs.items():
            if file_name.startswith("dashboards/"):
                config = update_id_refs(config, chart_ids, dataset_info)
                dashboard = import_dashboard(session, config, overwrite=True)
                for uuid in find_chart_uuids(config["position"]):
                    if uuid not in chart_ids:
                        continue  # skip charts that were not imported
                    chart_id = chart_ids[uuid]
                    if (dashboard.id, chart_id) not in existing_relationships:
                        dashboard_chart_ids.append((dashboard.id, chart_id))

        # set ref in the dashboard_slices table
        values = [
            {"dashboard_id": dashboard_id, "slice_id": chart_id}
            for (dashboard_id, chart_id) in dashboard_chart_ids
        ]
        # pylint: disable=no-value-for-parameter # sqlalchemy/issues/4656
        session.execute(dashboard_slices.insert(), values)
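
Note: find_chart_uuids walks the dashboard's position layout and collects the UUID of every chart node. A minimal sketch of the behavior it is assumed to have (the layout keys and meta structure below are illustrative, not the actual Superset implementation):

    from typing import Any, Dict, Set

    def find_chart_uuids_sketch(position: Dict[str, Any]) -> Set[str]:
        # Keep only mapping-valued layout nodes typed as charts.
        return {
            child["meta"]["uuid"]
            for child in position.values()
            if isinstance(child, dict) and child.get("type") == "CHART"
        }

    position = {
        "DASHBOARD_VERSION_KEY": "v2",
        "CHART-abc123": {
            "type": "CHART",
            "meta": {"uuid": "22222222-2222-2222-2222-222222222222", "chartId": 7},
        },
    }
    assert find_chart_uuids_sketch(position) == {"22222222-2222-2222-2222-222222222222"}
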
Example #4
    def _import(session: Session,
                configs: Dict[str, Any],
                overwrite: bool = False) -> None:
        # discover datasets associated with charts
        dataset_uuids: Set[str] = set()
        for file_name, config in configs.items():
            if file_name.startswith("charts/"):
                dataset_uuids.add(config["dataset_uuid"])

        # discover databases associated with datasets
        database_uuids: Set[str] = set()
        for file_name, config in configs.items():
            if (file_name.startswith("datasets/")
                    and config["uuid"] in dataset_uuids):
                database_uuids.add(config["database_uuid"])

        # import related databases
        database_ids: Dict[str, int] = {}
        for file_name, config in configs.items():
            if (file_name.startswith("databases/")
                    and config["uuid"] in database_uuids):
                database = import_database(session, config, overwrite=False)
                database_ids[str(database.uuid)] = database.id

        # import datasets with the correct parent ref
        datasets: Dict[str, SqlaTable] = {}
        for file_name, config in configs.items():
            if (file_name.startswith("datasets/")
                    and config["database_uuid"] in database_ids):
                config["database_id"] = database_ids[config["database_uuid"]]
                dataset = import_dataset(session, config, overwrite=False)
                datasets[str(dataset.uuid)] = dataset

        # import charts with the correct parent ref
        for file_name, config in configs.items():
            if (file_name.startswith("charts/")
                    and config["dataset_uuid"] in datasets):
                # update datasource id, type, and name
                dataset = datasets[config["dataset_uuid"]]
                config.update({
                    "datasource_id": dataset.id,
                    "datasource_type": "view" if dataset.is_sqllab_view else "table",
                    "datasource_name": dataset.table_name,
                })
                config["params"].update({"datasource": dataset.uid})
                if config["query_context"]:
                    # TODO (betodealmeida): export query_context as object, not string
                    query_context = json.loads(config["query_context"])
                    query_context["datasource"] = {
                        "id": dataset.id,
                        "type": "table"
                    }
                    config["query_context"] = json.dumps(query_context)

                import_chart(session, config, overwrite=overwrite)
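
Note: the query_context rewrite above is a plain JSON round-trip: the exported value is a serialized string, so it is parsed, its datasource pointer is replaced with the newly assigned ID, and the result is serialized back. The same steps in isolation (toy values):

    import json

    # Toy chart config whose query_context still points at the old datasource ID.
    config = {"query_context": json.dumps({"datasource": {"id": 1, "type": "table"}})}
    new_dataset_id = 42  # hypothetical ID assigned by import_dataset

    query_context = json.loads(config["query_context"])
    query_context["datasource"] = {"id": new_dataset_id, "type": "table"}
    config["query_context"] = json.dumps(query_context)

    assert json.loads(config["query_context"])["datasource"]["id"] == 42
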
Example #5
    def _import_bundle(self, session: Session) -> None:
        # discover databases associated with datasets
        database_uuids: Set[str] = set()
        for file_name, config in self._configs.items():
            if file_name.startswith("datasets/"):
                database_uuids.add(config["database_uuid"])

        # import related databases
        database_ids: Dict[str, int] = {}
        for file_name, config in self._configs.items():
            if (file_name.startswith("databases/")
                    and config["uuid"] in database_uuids):
                database = import_database(session, config, overwrite=False)
                database_ids[str(database.uuid)] = database.id

        # import datasets with the correct parent ref
        for file_name, config in self._configs.items():
            if (file_name.startswith("datasets/")
                    and config["database_uuid"] in database_ids):
                config["database_id"] = database_ids[config["database_uuid"]]
                import_dataset(session, config, overwrite=True)
Example #6
    def _import(session: Session,
                configs: Dict[str, Any],
                overwrite: bool = False) -> None:
        # discover datasets associated with charts
        dataset_uuids: Set[str] = set()
        for file_name, config in configs.items():
            if file_name.startswith("charts/"):
                dataset_uuids.add(config["dataset_uuid"])

        # discover databases associated with datasets
        database_uuids: Set[str] = set()
        for file_name, config in configs.items():
            if (file_name.startswith("datasets/")
                    and config["uuid"] in dataset_uuids):
                database_uuids.add(config["database_uuid"])

        # import related databases
        database_ids: Dict[str, int] = {}
        for file_name, config in configs.items():
            if (file_name.startswith("databases/")
                    and config["uuid"] in database_uuids):
                database = import_database(session, config, overwrite=False)
                database_ids[str(database.uuid)] = database.id

        # import datasets with the correct parent ref
        datasets: Dict[str, SqlaTable] = {}
        for file_name, config in configs.items():
            if (file_name.startswith("datasets/")
                    and config["database_uuid"] in database_ids):
                config["database_id"] = database_ids[config["database_uuid"]]
                dataset = import_dataset(session, config, overwrite=False)
                datasets[str(dataset.uuid)] = dataset

        # import charts with the correct parent ref
        for file_name, config in configs.items():
            if (file_name.startswith("charts/")
                    and config["dataset_uuid"] in datasets):
                # update datasource id, type, and name
                dataset = datasets[config["dataset_uuid"]]
                config.update({
                    "datasource_id": dataset.id,
                    "datasource_type": "table",
                    "datasource_name": dataset.table_name,
                })
                config["params"].update({"datasource": dataset.uid})

                if "query_context" in config:
                    del config["query_context"]

                import_chart(session, config, overwrite=overwrite)
Example #7
    def _import(session: Session,
                configs: Dict[str, Any],
                overwrite: bool = False) -> None:
        # discover datasets associated with charts
        dataset_uuids: Set[str] = set()
        for file_name, config in configs.items():
            if file_name.startswith("charts/"):
                dataset_uuids.add(config["dataset_uuid"])

        # discover databases associated with datasets
        database_uuids: Set[str] = set()
        for file_name, config in configs.items():
            if (file_name.startswith("datasets/")
                    and config["uuid"] in dataset_uuids):
                database_uuids.add(config["database_uuid"])

        # import related databases
        database_ids: Dict[str, int] = {}
        for file_name, config in configs.items():
            if (file_name.startswith("databases/")
                    and config["uuid"] in database_uuids):
                database = import_database(session, config, overwrite=False)
                database_ids[str(database.uuid)] = database.id

        # import datasets with the correct parent ref
        dataset_info: Dict[str, Dict[str, Any]] = {}
        for file_name, config in configs.items():
            if (file_name.startswith("datasets/")
                    and config["database_uuid"] in database_ids):
                config["database_id"] = database_ids[config["database_uuid"]]
                dataset = import_dataset(session, config, overwrite=False)
                dataset_info[str(dataset.uuid)] = {
                    "datasource_id": dataset.id,
                    "datasource_type":
                    "view" if dataset.is_sqllab_view else "table",
                    "datasource_name": dataset.table_name,
                }

        # import charts with the correct parent ref
        for file_name, config in configs.items():
            if (file_name.startswith("charts/")
                    and config["dataset_uuid"] in dataset_info):
                # update datasource id, type, and name
                config.update(dataset_info[config["dataset_uuid"]])
                import_chart(session, config, overwrite=overwrite)
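
Note: config.update(dataset_info[...]) merges the freshly assigned datasource fields straight into each chart config. The merge in isolation (hypothetical values):

    dataset_info = {
        "22222222-2222-2222-2222-222222222222": {
            "datasource_id": 42,
            "datasource_type": "table",
            "datasource_name": "my_table",
        },
    }
    config = {
        "dataset_uuid": "22222222-2222-2222-2222-222222222222",
        "slice_name": "My chart",
    }
    config.update(dataset_info[config["dataset_uuid"]])
    assert config["datasource_id"] == 42 and config["datasource_name"] == "my_table"
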
Example #8
    def _import(session: Session,
                configs: Dict[str, Any],
                overwrite: bool = False) -> None:
        # discover charts associated with dashboards
        chart_uuids: Set[str] = set()
        for file_name, config in configs.items():
            if file_name.startswith("dashboards/"):
                chart_uuids.update(find_chart_uuids(config["position"]))

        # discover datasets associated with charts
        dataset_uuids: Set[str] = set()
        for file_name, config in configs.items():
            if (file_name.startswith("charts/")
                    and config["uuid"] in chart_uuids):
                dataset_uuids.add(config["dataset_uuid"])

        # discover databases associated with datasets
        database_uuids: Set[str] = set()
        for file_name, config in configs.items():
            if (file_name.startswith("datasets/")
                    and config["uuid"] in dataset_uuids):
                database_uuids.add(config["database_uuid"])

        # import related databases
        database_ids: Dict[str, int] = {}
        for file_name, config in configs.items():
            if (file_name.startswith("databases/")
                    and config["uuid"] in database_uuids):
                database = import_database(session, config, overwrite=False)
                database_ids[str(database.uuid)] = database.id

        # import datasets with the correct parent ref
        dataset_info: Dict[str, Dict[str, Any]] = {}
        for file_name, config in configs.items():
            if (file_name.startswith("datasets/")
                    and config["database_uuid"] in database_ids):
                config["database_id"] = database_ids[config["database_uuid"]]
                dataset = import_dataset(session, config, overwrite=False)
                dataset_info[str(dataset.uuid)] = {
                    "datasource_id": dataset.id,
                    "datasource_type":
                    "view" if dataset.is_sqllab_view else "table",
                    "datasource_name": dataset.table_name,
                }

        # import charts with the correct parent ref
        chart_ids: Dict[str, int] = {}
        for file_name, config in configs.items():
            if (file_name.startswith("charts/")
                    and config["dataset_uuid"] in dataset_info):
                # update datasource id, type, and name
                config.update(dataset_info[config["dataset_uuid"]])
                chart = import_chart(session, config, overwrite=False)
                chart_ids[str(chart.uuid)] = chart.id

        # store the existing relationship between dashboards and charts
        existing_relationships = session.execute(
            select([dashboard_slices.c.dashboard_id, dashboard_slices.c.slice_id])
        ).fetchall()

        # import dashboards
        dashboard_chart_ids: List[Tuple[int, int]] = []
        for file_name, config in configs.items():
            if file_name.startswith("dashboards/"):
                dashboard = import_dashboard(session, config, overwrite=overwrite)

                for uuid in find_chart_uuids(config["position"]):
                    if uuid not in chart_ids:
                        continue  # chart was not imported (dataset or database missing)
                    chart_id = chart_ids[uuid]
                    if (dashboard.id, chart_id) not in existing_relationships:
                        dashboard_chart_ids.append((dashboard.id, chart_id))

        # set ref in the dashboard_slices table
        values = [
            {"dashboard_id": dashboard_id, "slice_id": chart_id}
            for (dashboard_id, chart_id) in dashboard_chart_ids
        ]
        # pylint: disable=no-value-for-parameter (sqlalchemy/issues/4656)
        session.execute(dashboard_slices.insert(), values)
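
Note: reading dashboard_slices up front is what makes the linking idempotent; a (dashboard_id, chart_id) pair is queued only if it is not already present. The filtering step on its own (toy values):

    from typing import List, Set, Tuple

    # Pairs already in dashboard_slices, as returned by the SELECT above.
    existing_relationships: Set[Tuple[int, int]] = {(1, 10), (1, 11)}
    # Pairs discovered while importing dashboards.
    imported: List[Tuple[int, int]] = [(1, 10), (1, 12), (2, 10)]

    dashboard_chart_ids = [
        pair for pair in imported if pair not in existing_relationships
    ]
    assert dashboard_chart_ids == [(1, 12), (2, 10)]
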
Example #9
    def _import(session: Session,
                configs: Dict[str, Any],
                overwrite: bool = False) -> None:
        # import databases
        database_ids: Dict[str, int] = {}
        for file_name, config in configs.items():
            if file_name.startswith("databases/"):
                database = import_database(session, config, overwrite=overwrite)
                database_ids[str(database.uuid)] = database.id

        # import datasets
        # TODO (betodealmeida): once we have all examples being imported we can
        # have a stable UUID for the database stored in the dataset YAML; for
        # now we need to fetch the current ID.
        examples_id = (
            db.session.query(Database).filter_by(database_name="examples").one().id
        )
        dataset_info: Dict[str, Dict[str, Any]] = {}
        for file_name, config in configs.items():
            if file_name.startswith("datasets/"):
                config["database_id"] = examples_id
                dataset = import_dataset(session, config, overwrite=overwrite)
                dataset_info[str(dataset.uuid)] = {
                    "datasource_id": dataset.id,
                    "datasource_type":
                    "view" if dataset.is_sqllab_view else "table",
                    "datasource_name": dataset.table_name,
                }

        # import charts
        chart_ids: Dict[str, int] = {}
        for file_name, config in configs.items():
            if file_name.startswith("charts/"):
                # update datasource id, type, and name
                config.update(dataset_info[config["dataset_uuid"]])
                chart = import_chart(session, config, overwrite=overwrite)
                chart_ids[str(chart.uuid)] = chart.id

        # store the existing relationship between dashboards and charts
        existing_relationships = session.execute(
            select([dashboard_slices.c.dashboard_id, dashboard_slices.c.slice_id])
        ).fetchall()

        # import dashboards
        dashboard_chart_ids: List[Tuple[int, int]] = []
        for file_name, config in configs.items():
            if file_name.startswith("dashboards/"):
                config = update_id_refs(config, chart_ids)
                dashboard = import_dashboard(session, config, overwrite=overwrite)
                for uuid in find_chart_uuids(config["position"]):
                    chart_id = chart_ids[uuid]
                    if (dashboard.id, chart_id) not in existing_relationships:
                        dashboard_chart_ids.append((dashboard.id, chart_id))

        # set ref in the dashboard_slices table
        values = [
            {"dashboard_id": dashboard_id, "slice_id": chart_id}
            for (dashboard_id, chart_id) in dashboard_chart_ids
        ]
        # pylint: disable=no-value-for-parameter (sqlalchemy/issues/4656)
        session.execute(dashboard_slices.insert(), values)
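
Note: the final session.execute(dashboard_slices.insert(), values) is SQLAlchemy Core's executemany form: one INSERT statement bound to a list of parameter dicts. A self-contained sketch with a stand-in table (hypothetical columns, not the real dashboard_slices definition):

    from sqlalchemy import Column, Integer, MetaData, Table, create_engine

    metadata = MetaData()
    dashboard_slices = Table(
        "dashboard_slices",
        metadata,
        Column("dashboard_id", Integer),
        Column("slice_id", Integer),
    )

    engine = create_engine("sqlite://")
    metadata.create_all(engine)

    values = [
        {"dashboard_id": 1, "slice_id": 12},
        {"dashboard_id": 2, "slice_id": 10},
    ]
    with engine.begin() as conn:  # one statement, many parameter sets
        conn.execute(dashboard_slices.insert(), values)
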
Example #10
def test_import_column_extra_is_string(app_context: None,
                                       session: Session) -> None:
    """
    Test importing a dataset when the column extra is a string.
    """
    import uuid
    from typing import Any, Dict

    from superset.connectors.sqla.models import SqlaTable
    from superset.datasets.commands.importers.v1.utils import import_dataset
    from superset.datasets.schemas import ImportV1DatasetSchema
    from superset.models.core import Database

    engine = session.get_bind()
    SqlaTable.metadata.create_all(engine)  # pylint: disable=no-member

    database = Database(database_name="my_database",
                        sqlalchemy_uri="sqlite://")
    session.add(database)
    session.flush()

    dataset_uuid = uuid.uuid4()
    yaml_config: Dict[str, Any] = {
        "version": "1.0.0",
        "table_name": "my_table",
        "main_dttm_col": "ds",
        "description": "This is the description",
        "default_endpoint": None,
        "offset": -8,
        "cache_timeout": 3600,
        "schema": "my_schema",
        "sql": None,
        "params": {
            "remote_id": 64,
            "database_name": "examples",
            "import_time": 1606677834,
        },
        "template_params": {
            "answer": "42",
        },
        "filter_select_enabled": True,
        "fetch_values_predicate": "foo IN (1, 2)",
        "extra": '{"warning_markdown": "*WARNING*"}',
        "uuid": dataset_uuid,
        "metrics": [{
            "metric_name": "cnt",
            "verbose_name": None,
            "metric_type": None,
            "expression": "COUNT(*)",
            "description": None,
            "d3format": None,
            "extra": '{"warning_markdown": null}',
            "warning_text": None,
        }],
        "columns": [{
            "column_name": "profit",
            "verbose_name": None,
            "is_dttm": False,
            "is_active": True,
            "type": "INTEGER",
            "groupby": False,
            "filterable": False,
            "expression": "revenue-expenses",
            "description": None,
            "python_date_format": None,
            "extra": '{"certified_by": "User"}',
        }],
        "database_uuid": database.uuid,
    }

    # the Marshmallow schema should convert strings to objects
    schema = ImportV1DatasetSchema()
    dataset_config = schema.load(yaml_config)
    dataset_config["database_id"] = database.id
    sqla_table = import_dataset(session, dataset_config)

    assert sqla_table.metrics[0].extra == '{"warning_markdown": null}'
    assert sqla_table.columns[0].extra == '{"certified_by": "User"}'
    assert sqla_table.extra == '{"warning_markdown": "*WARNING*"}'
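
Note: schema.load() is what converts the serialized "extra" strings into objects for validation before import. The same coercion idea as a minimal standalone marshmallow schema (illustrative only, not the actual ImportV1DatasetSchema):

    import json
    from marshmallow import Schema, fields, pre_load

    class ExtraCoercingSchema(Schema):
        extra = fields.Dict(allow_none=True)

        @pre_load
        def coerce_extra(self, data, **kwargs):
            # Accept "extra" either as a dict or as a JSON-encoded string.
            if isinstance(data.get("extra"), str):
                data["extra"] = json.loads(data["extra"])
            return data

    loaded = ExtraCoercingSchema().load({"extra": '{"warning_markdown": "*WARNING*"}'})
    assert loaded == {"extra": {"warning_markdown": "*WARNING*"}}
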
Example #11
def test_import_dataset(app_context: None, session: Session) -> None:
    """
    Test importing a dataset.
    """
    import json
    import uuid

    from superset.connectors.sqla.models import SqlaTable
    from superset.datasets.commands.importers.v1.utils import import_dataset
    from superset.models.core import Database

    engine = session.get_bind()
    SqlaTable.metadata.create_all(engine)  # pylint: disable=no-member

    database = Database(database_name="my_database",
                        sqlalchemy_uri="sqlite://")
    session.add(database)
    session.flush()

    dataset_uuid = uuid.uuid4()
    config = {
        "table_name": "my_table",
        "main_dttm_col": "ds",
        "description": "This is the description",
        "default_endpoint": None,
        "offset": -8,
        "cache_timeout": 3600,
        "schema": "my_schema",
        "sql": None,
        "params": {
            "remote_id": 64,
            "database_name": "examples",
            "import_time": 1606677834,
        },
        "template_params": {
            "answer": "42",
        },
        "filter_select_enabled": True,
        "fetch_values_predicate": "foo IN (1, 2)",
        "extra": {"warning_markdown": "*WARNING*"},
        "uuid": dataset_uuid,
        "metrics": [{
            "metric_name": "cnt",
            "verbose_name": None,
            "metric_type": None,
            "expression": "COUNT(*)",
            "description": None,
            "d3format": None,
            "extra": {"warning_markdown": None},
            "warning_text": None,
        }],
        "columns": [{
            "column_name": "profit",
            "verbose_name": None,
            "is_dttm": None,
            "is_active": None,
            "type": "INTEGER",
            "groupby": None,
            "filterable": None,
            "expression": "revenue-expenses",
            "description": None,
            "python_date_format": None,
            "extra": {
                "certified_by": "User",
            },
        }],
        "database_uuid": database.uuid,
        "database_id": database.id,
    }

    sqla_table = import_dataset(session, config)
    assert sqla_table.table_name == "my_table"
    assert sqla_table.main_dttm_col == "ds"
    assert sqla_table.description == "This is the description"
    assert sqla_table.default_endpoint is None
    assert sqla_table.offset == -8
    assert sqla_table.cache_timeout == 3600
    assert sqla_table.schema == "my_schema"
    assert sqla_table.sql is None
    assert sqla_table.params == json.dumps(
        {"remote_id": 64, "database_name": "examples", "import_time": 1606677834}
    )
    assert sqla_table.template_params == json.dumps({"answer": "42"})
    assert sqla_table.filter_select_enabled is True
    assert sqla_table.fetch_values_predicate == "foo IN (1, 2)"
    assert sqla_table.extra == '{"warning_markdown": "*WARNING*"}'
    assert sqla_table.uuid == dataset_uuid
    assert len(sqla_table.metrics) == 1
    assert sqla_table.metrics[0].metric_name == "cnt"
    assert sqla_table.metrics[0].verbose_name is None
    assert sqla_table.metrics[0].metric_type is None
    assert sqla_table.metrics[0].expression == "COUNT(*)"
    assert sqla_table.metrics[0].description is None
    assert sqla_table.metrics[0].d3format is None
    assert sqla_table.metrics[0].extra == '{"warning_markdown": null}'
    assert sqla_table.metrics[0].warning_text is None
    assert len(sqla_table.columns) == 1
    assert sqla_table.columns[0].column_name == "profit"
    assert sqla_table.columns[0].verbose_name is None
    assert sqla_table.columns[0].is_dttm is False
    assert sqla_table.columns[0].is_active is True
    assert sqla_table.columns[0].type == "INTEGER"
    assert sqla_table.columns[0].groupby is True
    assert sqla_table.columns[0].filterable is True
    assert sqla_table.columns[0].expression == "revenue-expenses"
    assert sqla_table.columns[0].description is None
    assert sqla_table.columns[0].python_date_format is None
    assert sqla_table.columns[0].extra == '{"certified_by": "User"}'
    assert sqla_table.database.uuid == database.uuid
    assert sqla_table.database.id == database.id
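
Note: the params and template_params assertions compare against json.dumps(...) because SqlaTable stores these dict-valued fields as JSON text. The comparison relies on json.dumps's default formatting:

    import json

    params = {"remote_id": 64, "database_name": "examples", "import_time": 1606677834}
    assert json.dumps(params) == (
        '{"remote_id": 64, "database_name": "examples", "import_time": 1606677834}'
    )
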
Example #12
    def _import(
        session: Session,
        configs: Dict[str, Any],
        overwrite: bool = False,
        force_data: bool = False,
    ) -> None:
        # import databases
        database_ids: Dict[str, int] = {}
        for file_name, config in configs.items():
            if file_name.startswith("databases/"):
                database = import_database(session, config, overwrite=overwrite)
                database_ids[str(database.uuid)] = database.id

        # import datasets
        # If database_uuid is not in the list of UUIDs it means that the examples
        # database was created before its UUID was frozen, so it has a random UUID.
        # We need to determine its ID so we can point the dataset to it.
        examples_db = (
            db.session.query(Database).filter_by(database_name="examples").first()
        )
        dataset_info: Dict[str, Dict[str, Any]] = {}
        for file_name, config in configs.items():
            if file_name.startswith("datasets/"):
                # find the ID of the corresponding database
                if config["database_uuid"] not in database_ids:
                    if examples_db is None:
                        raise Exception("Cannot find examples database")
                    config["database_id"] = examples_db.id
                else:
                    config["database_id"] = database_ids[config["database_uuid"]]

                dataset = import_dataset(
                    session, config, overwrite=overwrite, force_data=force_data
                )
                dataset_info[str(dataset.uuid)] = {
                    "datasource_id": dataset.id,
                    "datasource_type": "view" if dataset.is_sqllab_view else "table",
                    "datasource_name": dataset.table_name,
                }

        # import charts
        chart_ids: Dict[str, int] = {}
        for file_name, config in configs.items():
            if file_name.startswith("charts/"):
                # update datasource id, type, and name
                config.update(dataset_info[config["dataset_uuid"]])
                chart = import_chart(session, config, overwrite=overwrite)
                chart_ids[str(chart.uuid)] = chart.id

        # store the existing relationship between dashboards and charts
        existing_relationships = session.execute(
            select([dashboard_slices.c.dashboard_id, dashboard_slices.c.slice_id])
        ).fetchall()

        # import dashboards
        dashboard_chart_ids: List[Tuple[int, int]] = []
        for file_name, config in configs.items():
            if file_name.startswith("dashboards/"):
                config = update_id_refs(config, chart_ids, dataset_info)
                dashboard = import_dashboard(session, config, overwrite=overwrite)
                dashboard.published = True

                for uuid in find_chart_uuids(config["position"]):
                    chart_id = chart_ids[uuid]
                    if (dashboard.id, chart_id) not in existing_relationships:
                        dashboard_chart_ids.append((dashboard.id, chart_id))

        # set ref in the dashboard_slices table
        values = [
            {"dashboard_id": dashboard_id, "slice_id": chart_id}
            for (dashboard_id, chart_id) in dashboard_chart_ids
        ]
        # pylint: disable=no-value-for-parameter # sqlalchemy/issues/4656
        session.execute(dashboard_slices.insert(), values)
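
Note: the database_uuid fallback above amounts to a dict lookup with a default. A compact equivalent (hypothetical values):

    # Stand-ins for the values resolved above.
    database_ids = {"11111111-1111-1111-1111-111111111111": 3}
    examples_db_id = 1  # ID of the pre-existing "examples" database

    config = {"database_uuid": "99999999-9999-9999-9999-999999999999"}
    # Unknown UUID -> fall back to the examples database.
    config["database_id"] = database_ids.get(config["database_uuid"], examples_db_id)
    assert config["database_id"] == 1
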
Example #13
    def _import(  # pylint: disable=arguments-differ, too-many-locals, too-many-branches
        session: Session,
        configs: Dict[str, Any],
        overwrite: bool = False,
        force_data: bool = False,
    ) -> None:
        # import databases
        database_ids: Dict[str, int] = {}
        for file_name, config in configs.items():
            if file_name.startswith("databases/"):
                database = import_database(session, config, overwrite=overwrite)
                database_ids[str(database.uuid)] = database.id

        # import datasets
        # If database_uuid is not in the list of UUIDs it means that the examples
        # database was created before its UUID was frozen, so it has a random UUID.
        # We need to determine its ID so we can point the dataset to it.
        examples_db = get_example_database()
        dataset_info: Dict[str, Dict[str, Any]] = {}
        for file_name, config in configs.items():
            if file_name.startswith("datasets/"):
                # find the ID of the corresponding database
                if config["database_uuid"] not in database_ids:
                    if examples_db is None:
                        raise Exception("Cannot find examples database")
                    config["database_id"] = examples_db.id
                else:
                    config["database_id"] = database_ids[config["database_uuid"]]

                # set schema
                if config["schema"] is None:
                    config["schema"] = get_example_default_schema()

                try:
                    dataset = import_dataset(
                        session, config, overwrite=overwrite, force_data=force_data
                    )
                except MultipleResultsFound:
                    # Multiple result can be found for datasets. There was a bug in
                    # load-examples that resulted in datasets being loaded with a NULL
                    # schema. Users could then add a new dataset with the same name in
                    # the correct schema, resulting in duplicates, since the uniqueness
                    # constraint was not enforced correctly in the application logic.
                    # See https://github.com/apache/superset/issues/16051.
                    continue

                dataset_info[str(dataset.uuid)] = {
                    "datasource_id": dataset.id,
                    "datasource_type": "table",
                    "datasource_name": dataset.table_name,
                }

        # import charts
        chart_ids: Dict[str, int] = {}
        for file_name, config in configs.items():
            if (
                file_name.startswith("charts/")
                and config["dataset_uuid"] in dataset_info
            ):
                # update datasource id, type, and name
                config.update(dataset_info[config["dataset_uuid"]])
                chart = import_chart(session, config, overwrite=overwrite)
                chart_ids[str(chart.uuid)] = chart.id

        # store the existing relationship between dashboards and charts
        existing_relationships = session.execute(
            select([dashboard_slices.c.dashboard_id, dashboard_slices.c.slice_id])
        ).fetchall()

        # import dashboards
        dashboard_chart_ids: List[Tuple[int, int]] = []
        for file_name, config in configs.items():
            if file_name.startswith("dashboards/"):
                try:
                    config = update_id_refs(config, chart_ids, dataset_info)
                except KeyError:
                    continue

                dashboard = import_dashboard(session, config, overwrite=overwrite)
                dashboard.published = True

                for uuid in find_chart_uuids(config["position"]):
                    chart_id = chart_ids[uuid]
                    if (dashboard.id, chart_id) not in existing_relationships:
                        dashboard_chart_ids.append((dashboard.id, chart_id))

        # set ref in the dashboard_slices table
        values = [
            {"dashboard_id": dashboard_id, "slice_id": chart_id}
            for (dashboard_id, chart_id) in dashboard_chart_ids
        ]
        session.execute(dashboard_slices.insert(), values)
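
Note: the MultipleResultsFound branch guards against the duplicate rows left behind by the load-examples bug described in the comment. A self-contained sketch of the failure mode it skips (toy model, not the real SqlaTable):

    from sqlalchemy import Column, Integer, String, create_engine
    from sqlalchemy.orm import Session, declarative_base
    from sqlalchemy.orm.exc import MultipleResultsFound

    Base = declarative_base()

    class ToyDataset(Base):
        __tablename__ = "toy_datasets"
        id = Column(Integer, primary_key=True)
        table_name = Column(String)
        schema = Column(String)

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        # Same table name twice, one row with a NULL schema: .one() now raises.
        session.add_all([
            ToyDataset(table_name="my_table", schema=None),
            ToyDataset(table_name="my_table", schema="main"),
        ])
        session.flush()
        try:
            session.query(ToyDataset).filter_by(table_name="my_table").one()
        except MultipleResultsFound:
            print("duplicate datasets found; skipped, as in the importer above")
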