Example #1
class MetabaseSource(Source):
    """
    This plugin extracts charts, dashboards, and associated metadata. It is in beta and has only been
    tested against PostgreSQL and H2 databases.
    ### Dashboard

    The [/api/dashboard](https://www.metabase.com/docs/latest/api-documentation.html#dashboard) endpoint is used to
    retrieve the following dashboard information.

    - Title and description
    - Last edited by
    - Owner
    - Link to the dashboard in Metabase
    - Associated charts

    ### Chart

    The [/api/card](https://www.metabase.com/docs/latest/api-documentation.html#card) endpoint is used to
    retrieve the following information.

    - Title and description
    - Last edited by
    - Owner
    - Link to the chart in Metabase
    - Datasource and lineage

    The following chart properties are ingested into DataHub.

    | Name          | Description                                     |
    | ------------- | ----------------------------------------------- |
    | `Dimensions`  | Column names                                    |
    | `Filters`     | Any filters applied to the chart                |
    | `Metrics`     | All columns that are being used for aggregation |


    """

    config: MetabaseConfig
    report: SourceReport
    platform = "metabase"

    def __hash__(self):
        return id(self)

    def __init__(self, ctx: PipelineContext, config: MetabaseConfig):
        super().__init__(ctx)
        self.config = config
        self.report = SourceReport()

        login_response = requests.post(
            f"{self.config.connect_uri}/api/session",
            json={
                "username": self.config.username,
                "password": self.config.password,
            },
        )

        login_response.raise_for_status()
        self.access_token = login_response.json().get("id", "")

        self.session = requests.session()
        self.session.headers.update({
            "X-Metabase-Session": self.access_token,
            "Content-Type": "application/json",
            "Accept": "*/*",
        })

        # Test the connection
        try:
            test_response = self.session.get(
                f"{self.config.connect_uri}/api/user/current")
            test_response.raise_for_status()
        except HTTPError as e:
            self.report.report_failure(
                key="metabase-session",
                reason=f"Unable to retrieve user {self.config.username} "
                f"information. {str(e)}",
            )

    def close(self) -> None:
        response = requests.delete(
            f"{self.config.connect_uri}/api/session",
            headers={"X-Metabase-Session": self.access_token},
        )
        if response.status_code not in (200, 204):
            self.report.report_failure(
                key="metabase-session",
                reason=f"Unable to logout for user {self.config.username}",
            )

    def emit_dashboard_mces(self) -> Iterable[MetadataWorkUnit]:
        try:
            dashboard_response = self.session.get(
                f"{self.config.connect_uri}/api/dashboard")
            dashboard_response.raise_for_status()
            dashboards = dashboard_response.json()

            for dashboard_info in dashboards:
                dashboard_snapshot = self.construct_dashboard_from_api_data(
                    dashboard_info)
                if dashboard_snapshot is not None:
                    mce = MetadataChangeEvent(
                        proposedSnapshot=dashboard_snapshot)
                    wu = MetadataWorkUnit(id=dashboard_snapshot.urn, mce=mce)
                    self.report.report_workunit(wu)
                    yield wu

        except HTTPError as http_error:
            self.report.report_failure(
                key="metabase-dashboard",
                reason=f"Unable to retrieve dashboards. "
                f"Reason: {str(http_error)}",
            )

    @staticmethod
    def get_timestamp_millis_from_ts_string(ts_str: str) -> int:
        """
        Converts the given timestamp string to milliseconds. If parsing fails,
        returns the current UTC time in milliseconds.
        """
        try:
            return int(dp.parse(ts_str).timestamp() * 1000)
        except (dp.ParserError, OverflowError):
            return int(datetime.utcnow().timestamp() * 1000)

    def construct_dashboard_from_api_data(
            self, dashboard_info: dict) -> Optional[DashboardSnapshot]:

        dashboard_id = dashboard_info.get("id", "")
        dashboard_url = f"{self.config.connect_uri}/api/dashboard/{dashboard_id}"
        try:
            dashboard_response = self.session.get(dashboard_url)
            dashboard_response.raise_for_status()
            dashboard_details = dashboard_response.json()
        except HTTPError as http_error:
            self.report.report_failure(
                key=f"metabase-dashboard-{dashboard_id}",
                reason=f"Unable to retrieve dashboard. "
                f"Reason: {str(http_error)}",
            )
            return None

        dashboard_urn = builder.make_dashboard_urn(
            self.platform, dashboard_details.get("id", ""))
        dashboard_snapshot = DashboardSnapshot(
            urn=dashboard_urn,
            aspects=[],
        )
        last_edit_by = dashboard_details.get("last-edit-info") or {}
        modified_actor = builder.make_user_urn(
            last_edit_by.get("email", "unknown"))
        modified_ts = self.get_timestamp_millis_from_ts_string(
            f"{last_edit_by.get('timestamp')}")
        title = dashboard_details.get("name", "") or ""
        description = dashboard_details.get("description", "") or ""
        last_modified = ChangeAuditStamps(
            created=AuditStamp(time=modified_ts, actor=modified_actor),
            lastModified=AuditStamp(time=modified_ts, actor=modified_actor),
        )

        chart_urns = []
        cards_data = dashboard_details.get("ordered_cards", [])
        for card_info in cards_data:
            chart_urn = builder.make_chart_urn(self.platform,
                                               card_info.get("id", ""))
            chart_urns.append(chart_urn)

        dashboard_info_class = DashboardInfoClass(
            description=description,
            title=title,
            charts=chart_urns,
            lastModified=last_modified,
            dashboardUrl=f"{self.config.connect_uri}/dashboard/{dashboard_id}",
            customProperties={},
        )
        dashboard_snapshot.aspects.append(dashboard_info_class)

        # Ownership
        ownership = self._get_ownership(dashboard_details.get(
            "creator_id", ""))
        if ownership is not None:
            dashboard_snapshot.aspects.append(ownership)

        return dashboard_snapshot

    @lru_cache(maxsize=None)
    def _get_ownership(self, creator_id: int) -> Optional[OwnershipClass]:
        user_info_url = f"{self.config.connect_uri}/api/user/{creator_id}"
        try:
            user_info_response = self.session.get(user_info_url)
            user_info_response.raise_for_status()
            user_details = user_info_response.json()
        except HTTPError as http_error:
            self.report.report_failure(
                key=f"metabase-user-{creator_id}",
                reason=f"Unable to retrieve User info. "
                f"Reason: {str(http_error)}",
            )
            return None

        owner_urn = builder.make_user_urn(user_details.get("email", ""))
        if owner_urn is not None:
            ownership: OwnershipClass = OwnershipClass(owners=[
                OwnerClass(
                    owner=owner_urn,
                    type=OwnershipTypeClass.DATAOWNER,
                )
            ])
            return ownership

        return None

    def emit_card_mces(self) -> Iterable[MetadataWorkUnit]:
        try:
            card_response = self.session.get(
                f"{self.config.connect_uri}/api/card")
            card_response.raise_for_status()
            cards = card_response.json()

            for card_info in cards:
                chart_snapshot = self.construct_card_from_api_data(card_info)
                if chart_snapshot is not None:
                    mce = MetadataChangeEvent(proposedSnapshot=chart_snapshot)
                    wu = MetadataWorkUnit(id=chart_snapshot.urn, mce=mce)
                    self.report.report_workunit(wu)
                    yield wu

        except HTTPError as http_error:
            self.report.report_failure(
                key="metabase-cards",
                reason=f"Unable to retrieve cards. "
                f"Reason: {str(http_error)}",
            )
            return

    def construct_card_from_api_data(
            self, card_data: dict) -> Optional[ChartSnapshot]:
        card_id = card_data.get("id", "")
        card_url = f"{self.config.connect_uri}/api/card/{card_id}"
        try:
            card_response = self.session.get(card_url)
            card_response.raise_for_status()
            card_details = card_response.json()
        except HTTPError as http_error:
            self.report.report_failure(
                key=f"metabase-card-{card_id}",
                reason=f"Unable to retrieve Card info. "
                f"Reason: {str(http_error)}",
            )
            return None

        chart_urn = builder.make_chart_urn(self.platform, card_id)
        chart_snapshot = ChartSnapshot(
            urn=chart_urn,
            aspects=[],
        )

        last_edit_by = card_details.get("last-edit-info") or {}
        modified_actor = builder.make_user_urn(
            last_edit_by.get("email", "unknown"))
        modified_ts = self.get_timestamp_millis_from_ts_string(
            f"{last_edit_by.get('timestamp')}")
        last_modified = ChangeAuditStamps(
            created=AuditStamp(time=modified_ts, actor=modified_actor),
            lastModified=AuditStamp(time=modified_ts, actor=modified_actor),
        )

        chart_type = self._get_chart_type(card_details.get("id", ""),
                                          card_details.get("display"))
        description = card_details.get("description") or ""
        title = card_details.get("name") or ""
        datasource_urn = self.get_datasource_urn(card_details)
        custom_properties = self.construct_card_custom_properties(card_details)

        chart_info = ChartInfoClass(
            type=chart_type,
            description=description,
            title=title,
            lastModified=last_modified,
            chartUrl=f"{self.config.connect_uri}/card/{card_id}",
            inputs=datasource_urn,
            customProperties=custom_properties,
        )
        chart_snapshot.aspects.append(chart_info)

        if card_details.get("query_type", "") == "native":
            raw_query = (card_details.get("dataset_query",
                                          {}).get("native",
                                                  {}).get("query", ""))
            chart_query_native = ChartQueryClass(
                rawQuery=raw_query,
                type=ChartQueryTypeClass.SQL,
            )
            chart_snapshot.aspects.append(chart_query_native)

        # Ownership
        ownership = self._get_ownership(card_details.get("creator_id", ""))
        if ownership is not None:
            chart_snapshot.aspects.append(ownership)

        return chart_snapshot

    def _get_chart_type(self, card_id: int,
                        display_type: str) -> Optional[str]:
        type_mapping = {
            "table": ChartTypeClass.TABLE,
            "bar": ChartTypeClass.BAR,
            "line": ChartTypeClass.LINE,
            "row": ChartTypeClass.BAR,
            "area": ChartTypeClass.AREA,
            "pie": ChartTypeClass.PIE,
            "funnel": ChartTypeClass.BAR,
            "scatter": ChartTypeClass.SCATTER,
            "scalar": ChartTypeClass.TEXT,
            "smartscalar": ChartTypeClass.TEXT,
            "pivot": ChartTypeClass.TABLE,
            "waterfall": ChartTypeClass.BAR,
            "progress": None,
            "combo": None,
            "gauge": None,
            "map": None,
        }
        if not display_type:
            self.report.report_warning(
                key=f"metabase-card-{card_id}",
                reason="Card display type is missing. Setting chart type to None",
            )
            return None
        try:
            chart_type = type_mapping[display_type]
        except KeyError:
            self.report.report_warning(
                key=f"metabase-card-{card_id}",
                reason=
                f"Chart type {display_type} not supported. Setting to None",
            )
            chart_type = None

        return chart_type

    def construct_card_custom_properties(self, card_details: dict) -> Dict:
        result_metadata = card_details.get("result_metadata") or []
        metrics, dimensions = [], []
        for meta_data in result_metadata:
            display_name = meta_data.get("display_name", "") or ""
            if "aggregation" in meta_data.get("field_ref", ""):
                metrics.append(display_name)
            else:
                dimensions.append(display_name)

        filters = (card_details.get("dataset_query",
                                    {}).get("query", {})).get("filter", [])

        custom_properties = {
            "Metrics": ", ".join(metrics),
            "Filters": f"{filters}" if len(filters) else "",
            "Dimensions": ", ".join(dimensions),
        }

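        # For example, a card with one aggregation column might yield
        # (hypothetical values):
        #   {"Metrics": "Sum of Quantity",
        #    "Filters": "['time-interval', ['field', 1, None], -30, 'day']",
        #    "Dimensions": "Created At, Category"}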
        return custom_properties

    def get_datasource_urn(self, card_details):
        platform, database_name, platform_instance = self.get_datasource_from_id(
            card_details.get("database_id", ""))
        query_type = card_details.get("dataset_query", {}).get("type", {})
        source_paths = set()

        if query_type == "query":
            source_table_id = (card_details.get("dataset_query", {}).get(
                "query", {}).get("source-table"))
            if source_table_id is not None:
                schema_name, table_name = self.get_source_table_from_id(
                    source_table_id)
                if table_name:
                    source_paths.add(
                        f"{schema_name + '.' if schema_name else ''}{table_name}"
                    )
        else:
            try:
                raw_query = (card_details.get("dataset_query",
                                              {}).get("native",
                                                      {}).get("query", ""))
                parser = LineageRunner(raw_query)

                for table in parser.source_tables:
                    sources = str(table).split(".")
                    source_schema, source_table = sources[-2], sources[-1]
                    if source_schema == "<default>":
                        source_schema = str(self.config.default_schema)

                    source_paths.add(f"{source_schema}.{source_table}")
            except Exception as e:
                self.report.report_failure(
                    key="metabase-query",
                    reason=f"Unable to retrieve lineage from query. "
                    f"Query: {raw_query} "
                    f"Reason: {str(e)} ",
                )
                return None

        # Create dataset URNs
        dbname = f"{database_name + '.' if database_name else ''}"
        source_tables = list(map(lambda tbl: f"{dbname}{tbl}", source_paths))
        dataset_urn = [
            builder.make_dataset_urn_with_platform_instance(
                platform=platform,
                name=name,
                platform_instance=platform_instance,
                env=self.config.env,
            ) for name in source_tables
        ]

        return dataset_urn

    @lru_cache(maxsize=None)
    def get_source_table_from_id(self, table_id):
        try:
            dataset_response = self.session.get(
                f"{self.config.connect_uri}/api/table/{table_id}")
            dataset_response.raise_for_status()
            dataset_json = dataset_response.json()
            schema = dataset_json.get("schema", "")
            name = dataset_json.get("name", "")
            return schema, name

        except HTTPError as http_error:
            self.report.report_failure(
                key=f"metabase-table-{table_id}",
                reason=f"Unable to retrieve source table. "
                f"Reason: {str(http_error)}",
            )

        return None, None

    @lru_cache(maxsize=None)
    def get_datasource_from_id(self, datasource_id):
        try:
            dataset_response = self.session.get(
                f"{self.config.connect_uri}/api/database/{datasource_id}")
            dataset_response.raise_for_status()
            dataset_json = dataset_response.json()
        except HTTPError as http_error:
            self.report.report_failure(
                key=f"metabase-datasource-{datasource_id}",
                reason=f"Unable to retrieve Datasource. "
                f"Reason: {str(http_error)}",
            )
            return None, None

        # Map engine names to what datahub expects in
        # https://github.com/datahub-project/datahub/blob/master/metadata-service/war/src/main/resources/boot/data_platforms.json
        engine = dataset_json.get("engine", "")
        platform = engine

        engine_mapping = {
            "sparksql": "spark",
            "mongo": "mongodb",
            "presto-jdbc": "presto",
            "sqlserver": "mssql",
            "bigquery-cloud-sdk": "bigquery",
        }

        if self.config.engine_platform_map is not None:
            engine_mapping.update(self.config.engine_platform_map)

        if engine in engine_mapping:
            platform = engine_mapping[engine]
        else:
            self.report.report_warning(
                key=f"metabase-platform-{datasource_id}",
                reason=
                f"Platform was not found in DataHub. Using {platform} name as is",
            )
        # Set platform_instance if configuration provides a mapping from platform name to instance
        platform_instance = (self.config.platform_instance_map.get(platform)
                             if self.config.platform_instance_map else None)

        field_for_dbname_mapping = {
            "postgres": "dbname",
            "sparksql": "dbname",
            "mongo": "dbname",
            "redshift": "db",
            "snowflake": "db",
            "presto-jdbc": "catalog",
            "presto": "catalog",
            "mysql": "dbname",
            "sqlserver": "db",
        }

        dbname = (dataset_json.get("details", {}).get(
            field_for_dbname_mapping[engine])
                  if engine in field_for_dbname_mapping else None)

        if (self.config.database_alias_map is not None
                and platform in self.config.database_alias_map):
            dbname = self.config.database_alias_map[platform]
        elif dbname is None:
            self.report.report_warning(
                key=f"metabase-dbname-{datasource_id}",
                reason=
                f"Cannot determine database name for platform: {platform}",
            )

        return platform, dbname, platform_instance

    @classmethod
    def create(cls, config_dict: dict, ctx: PipelineContext) -> Source:
        config = MetabaseConfig.parse_obj(config_dict)
        return cls(ctx, config)

    def get_workunits(self) -> Iterable[MetadataWorkUnit]:
        yield from self.emit_dashboard_mces()
        yield from self.emit_card_mces()

    def get_report(self) -> SourceReport:
        return self.report
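
# Hypothetical standalone usage sketch (normally the ingestion framework drives
# this via MetabaseSource.create() from a recipe); the config fields follow the
# attributes referenced above:
#
#     config_dict = {
#         "connect_uri": "http://localhost:3000",
#         "username": "datahub@example.com",
#         "password": "secret",
#     }
#     source = MetabaseSource.create(config_dict, PipelineContext(run_id="metabase-run"))
#     for wu in source.get_workunits():
#         print(wu.id)
#     source.close()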
Example #2
class ModeSource(Source):
    config: ModeConfig
    report: SourceReport
    platform = "mode"

    def __hash__(self):
        return id(self)

    def __init__(self, ctx: PipelineContext, config: ModeConfig):
        super().__init__(ctx)
        self.config = config
        self.report = SourceReport()

        self.session = requests.session()
        self.session.auth = HTTPBasicAuth(self.config.token,
                                          self.config.password)
        self.session.headers.update({
            "Content-Type": "application/json",
            "Accept": "application/hal+json",
        })

        # Test the connection
        try:
            self._get_request_json(f"{self.config.connect_uri}/api/account")
        except HTTPError as http_error:
            self.report.report_failure(
                key="mode-session",
                reason=f"Unable to retrieve user "
                f"{self.config.token} information, "
                f"{str(http_error)}",
            )

        self.workspace_uri = f"{self.config.connect_uri}/api/{self.config.workspace}"
        self.space_tokens = self._get_space_name_and_tokens()

    def construct_dashboard(self, space_name: str,
                            report_info: dict) -> DashboardSnapshot:
        report_token = report_info.get("token", "")
        dashboard_urn = builder.make_dashboard_urn(self.platform,
                                                   report_info.get("id", ""))
        dashboard_snapshot = DashboardSnapshot(
            urn=dashboard_urn,
            aspects=[],
        )

        title = report_info.get("name", "") or ""
        description = report_info.get("description", "") or ""
        last_modified = ChangeAuditStamps()
        creator = self._get_creator(
            report_info.get("_links", {}).get("creator", {}).get("href", ""))
        if creator is not None:
            modified_actor = builder.make_user_urn(creator)
            modified_ts = int(
                dp.parse(
                    f"{report_info.get('last_saved_at', 'now')}").timestamp() *
                1000)
            created_ts = int(
                dp.parse(
                    f"{report_info.get('created_at', 'now')}").timestamp() *
                1000)
            last_modified = ChangeAuditStamps(
                created=AuditStamp(time=created_ts, actor=modified_actor),
                lastModified=AuditStamp(time=modified_ts,
                                        actor=modified_actor),
            )

        dashboard_info_class = DashboardInfoClass(
            description=description,
            title=title,
            charts=self._get_chart_urns(report_token),
            lastModified=last_modified,
            dashboardUrl=f"{self.config.connect_uri}/"
            f"{self.config.workspace}/"
            f"reports/{report_token}",
            customProperties={},
        )
        dashboard_snapshot.aspects.append(dashboard_info_class)

        # browse path
        browse_path = BrowsePathsClass(paths=[
            f"/mode/{self.config.workspace}/"
            f"{space_name}/"
            f"{report_info.get('name')}"
        ])
        dashboard_snapshot.aspects.append(browse_path)

        # Ownership
        ownership = self._get_ownership(
            self._get_creator(
                report_info.get("_links", {}).get("creator",
                                                  {}).get("href", "")))
        if ownership is not None:
            dashboard_snapshot.aspects.append(ownership)

        return dashboard_snapshot

    @lru_cache(maxsize=None)
    def _get_ownership(self, user: str) -> Optional[OwnershipClass]:
        if user is not None:
            owner_urn = builder.make_user_urn(user)
            ownership: OwnershipClass = OwnershipClass(owners=[
                OwnerClass(
                    owner=owner_urn,
                    type=OwnershipTypeClass.DATAOWNER,
                )
            ])
            return ownership

        return None

    @lru_cache(maxsize=None)
    def _get_creator(self, href: str) -> Optional[str]:
        user = None
        try:
            user_json = self._get_request_json(
                f"{self.config.connect_uri}{href}")
            user = (user_json.get("username")
                    if self.config.owner_username_instead_of_email else
                    user_json.get("email"))
        except HTTPError as http_error:
            self.report.report_failure(
                key="mode-user",
                reason=f"Unable to retrieve user for {href}, "
                f"Reason: {str(http_error)}",
            )
        return user

    def _get_chart_urns(self, report_token: str) -> list:
        chart_urns = []
        queries = self._get_queries(report_token)
        for query in queries:
            charts = self._get_charts(report_token, query.get("token", ""))
            # build chart urns
            for chart in charts:
                chart_urn = builder.make_chart_urn(self.platform,
                                                   chart.get("token", ""))
                chart_urns.append(chart_urn)

        return chart_urns

    def _get_space_name_and_tokens(self) -> dict:
        space_info = {}
        try:
            payload = self._get_request_json(f"{self.workspace_uri}/spaces")
            spaces = payload.get("_embedded", {}).get("spaces", {})

            for s in spaces:
                space_info[s.get("token", "")] = s.get("name", "")
        except HTTPError as http_error:
            self.report.report_failure(
                key="mode-spaces",
                reason=
                f"Unable to retrieve spaces/collections for {self.workspace_uri}, "
                f"Reason: {str(http_error)}",
            )

        return space_info

    def _get_chart_type(self, token: str, display_type: str) -> Optional[str]:
        type_mapping = {
            "table": ChartTypeClass.TABLE,
            "bar": ChartTypeClass.BAR,
            "line": ChartTypeClass.LINE,
            "stackedBar100": ChartTypeClass.BAR,
            "stackedBar": ChartTypeClass.BAR,
            "hStackedBar": ChartTypeClass.BAR,
            "hStackedBar100": ChartTypeClass.BAR,
            "hBar": ChartTypeClass.BAR,
            "area": ChartTypeClass.AREA,
            "totalArea": ChartTypeClass.AREA,
            "pie": ChartTypeClass.PIE,
            "donut": ChartTypeClass.PIE,
            "scatter": ChartTypeClass.SCATTER,
            "bigValue": ChartTypeClass.TEXT,
            "pivotTable": ChartTypeClass.TABLE,
            "linePlusBar": None,
        }
        if not display_type:
            self.report.report_warning(
                key=f"mode-chart-{token}",
                reason="Chart display type is missing. Setting to None",
            )
            return None
        try:
            chart_type = type_mapping[display_type]
        except KeyError:
            self.report.report_warning(
                key=f"mode-chart-{token}",
                reason=f"Chart type {display_type} not supported. "
                f"Setting to None",
            )
            chart_type = None

        return chart_type

    def construct_chart_custom_properties(self, chart_detail: dict,
                                          chart_type: str) -> Dict:

        custom_properties = {}
        metadata = chart_detail.get("encoding", {})
        if chart_type == "table":
            columns = list(chart_detail.get("fieldFormats", {}).keys())
            str_columns = ",".join([c[1:-1] for c in columns])
            filters = metadata.get("filter", [])
            filters = filters[0].get("formula", "") if len(filters) else ""

            custom_properties = {
                "Columns": str_columns,
                "Filters": filters[1:-1] if len(filters) else "",
            }

        elif chart_type == "pivotTable":
            pivot_table = chart_detail.get("pivotTable", {})
            columns = pivot_table.get("columns", [])
            rows = pivot_table.get("rows", [])
            values = pivot_table.get("values", [])
            filters = pivot_table.get("filters", [])

            custom_properties = {
                "Columns": ", ".join(columns) if len(columns) else "",
                "Rows": ", ".join(rows) if len(rows) else "",
                "Metrics": ", ".join(values) if len(values) else "",
                "Filters": ", ".join(filters) if len(filters) else "",
            }
            # list filters in their own row
            for filter_name in filters:
                custom_properties[f"Filter: {filter_name}"] = ", ".join(
                    pivot_table.get("filterValues", {}).get(filter_name, ""))
        # Chart
        else:
            x = metadata.get("x", [])
            x2 = metadata.get("x2", [])
            y = metadata.get("y", [])
            y2 = metadata.get("y2", [])
            value = metadata.get("value", [])
            filters = metadata.get("filter", [])

            custom_properties = {
                "X": x[0].get("formula", "") if len(x) else "",
                "Y": y[0].get("formula", "") if len(y) else "",
                "X2": x2[0].get("formula", "") if len(x2) else "",
                "Y2": y2[0].get("formula", "") if len(y2) else "",
                "Metrics": value[0].get("formula", "") if len(value) else "",
                "Filters":
                filters[0].get("formula", "") if len(filters) else "",
            }

        return custom_properties

    def _get_datahub_friendly_platform(self, adapter, platform):
        # Map adaptor names to what datahub expects in
        # https://github.com/linkedin/datahub/blob/master/metadata-service/war/src/main/resources/boot/data_platforms.json

        platform_mapping = {
            "jdbc:athena": "athena",
            "jdbc:bigquery": "bigquery",
            "jdbc:druid": "druid",
            "jdbc:hive": "hive",
            "jdbc:mysql": "mysql",
            "jdbc:oracle": "oracle",
            "jdbc:postgresql": "postgres",
            "jdbc:presto": "presto",
            "jdbc:redshift": "redshift",
            "jdbc:snowflake": "snowflake",
            "jdbc:spark": "spark",
            "jdbc:sqlserver": "mssql",
            "jdbc:teradata": "teradata",
        }
        if adapter in platform_mapping:
            return platform_mapping[adapter]
        else:
            self.report.report_warning(
                key=f"mode-platform-{adapter}",
                reason=f"Platform was not found in DataHub. "
                f"Using {platform} name as is",
            )

        return platform

    @lru_cache(maxsize=None)
    def _get_platform_and_dbname(
            self,
            data_source_id: int) -> Union[Tuple[str, str], Tuple[None, None]]:

        data_sources = []
        try:
            ds_json = self._get_request_json(
                f"{self.workspace_uri}/data_sources")
            data_sources = ds_json.get("_embedded", {}).get("data_sources", [])
        except HTTPError as http_error:
            self.report.report_failure(
                key=f"mode-datasource-{data_source_id}",
                reason=f"No data sources found for datasource id: "
                f"{data_source_id}, "
                f"Reason: {str(http_error)}",
            )

        if not data_sources:
            self.report.report_failure(
                key=f"mode-datasource-{data_source_id}",
                reason=f"No data sources found for datasource id: "
                f"{data_source_id}",
            )
            return None, None

        for data_source in data_sources:
            if data_source.get("id", -1) == data_source_id:
                platform = self._get_datahub_friendly_platform(
                    data_source.get("adapter", ""),
                    data_source.get("name", ""))
                database = data_source.get("database", "")
                return platform, database
        else:
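            # for/else: reached only when the loop finds no matching data source id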
            self.report.report_failure(
                key=f"mode-datasource-{data_source_id}",
                reason=f"Cannot create datasource urn for datasource id: "
                f"{data_source_id}",
            )
        return None, None

    def _replace_definitions(self, raw_query: str) -> str:
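        # Sketch of the substitution (hypothetical definition named "users_def"
        # whose source is "SELECT id, email FROM users"):
        #   "SELECT * FROM {{ @users_def as u }}"
        #   becomes "SELECT * FROM (SELECT id, email FROM users) as u"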
        query = raw_query
        definitions = re.findall("({{[^}{]+}})", raw_query)
        for definition_variable in definitions:
            definition_name, definition_alias = self._parse_definition_name(
                definition_variable)
            definition_query = self._get_definition(definition_name)
            # if unable to retrieve the definition, replace the {{ }} so that it doesn't get picked up again in the recursive call
            if definition_query is not None:
                query = query.replace(
                    definition_variable,
                    f"({definition_query}) as {definition_alias}")
            else:
                query = query.replace(
                    definition_variable,
                    f"{definition_name} as {definition_alias}")
            query = self._replace_definitions(query)

        return query

    def _parse_definition_name(self,
                               definition_variable: str) -> Tuple[str, str]:
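        # e.g. (hypothetical) "{{ @join_on_definition as alias }}" parses to
        # ("join_on_definition", "alias")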
        name, alias = "", ""
        name_match = re.findall(r"@[a-zA-Z_]+", definition_variable)
        if len(name_match):
            name = name_match[0][1:]
        alias_match = re.findall(
            r"as\s+\S+", definition_variable)  # e.g. ['as    alias_name']
        if len(alias_match):
            alias_match = alias_match[0].split(" ")
            alias = alias_match[-1]

        return name, alias

    @lru_cache(maxsize=None)
    def _get_definition(self, definition_name):
        try:
            definition_json = self._get_request_json(
                f"{self.workspace_uri}/definitions")
            definitions = definition_json.get("_embedded",
                                              {}).get("definitions", [])
            for definition in definitions:
                if definition.get("name", "") == definition_name:
                    return definition.get("source", "")

        except HTTPError as http_error:
            self.report.report_failure(
                key=f"mode-definition-{definition_name}",
                reason=f"Unable to retrieve definition for {definition_name}, "
                f"Reason: {str(http_error)}",
            )
        return None

    @lru_cache(maxsize=None)
    def _get_source_from_query(self, raw_query: str) -> set:
        query = self._replace_definitions(raw_query)
        parser = LineageRunner(query)
        source_paths = set()
        try:
            for table in parser.source_tables:
                sources = str(table).split(".")
                source_schema, source_table = sources[-2], sources[-1]
                if source_schema == "<default>":
                    source_schema = str(self.config.default_schema)

                source_paths.add(f"{source_schema}.{source_table}")
        except Exception as e:
            self.report.report_failure(
                key="mode-query",
                reason=f"Unable to retrieve lineage from query. "
                f"Query: {raw_query} "
                f"Reason: {str(e)} ",
            )

        return source_paths

    def _get_datasource_urn(self, platform, database, source_tables):
        dataset_urn = None
        if platform is not None and database is not None:
            dataset_urn = [
                builder.make_dataset_urn(platform, f"{database}.{s_table}",
                                         self.config.env)
                for s_table in source_tables
            ]

        return dataset_urn

    def construct_chart_from_api_data(self, chart_data: dict, query: dict,
                                      path: str) -> ChartSnapshot:
        chart_urn = builder.make_chart_urn(self.platform,
                                           chart_data.get("token", ""))
        chart_snapshot = ChartSnapshot(
            urn=chart_urn,
            aspects=[],
        )

        last_modified = ChangeAuditStamps()
        creator = self._get_creator(
            chart_data.get("_links", {}).get("creator", {}).get("href", ""))
        if creator is not None:
            modified_actor = builder.make_user_urn(creator)
            created_ts = int(
                dp.parse(chart_data.get("created_at", "now")).timestamp() *
                1000)
            modified_ts = int(
                dp.parse(chart_data.get("updated_at", "now")).timestamp() *
                1000)
            last_modified = ChangeAuditStamps(
                created=AuditStamp(time=created_ts, actor=modified_actor),
                lastModified=AuditStamp(time=modified_ts,
                                        actor=modified_actor),
            )

        chart_detail = (chart_data.get("view", {})
                        if len(chart_data.get("view", {})) != 0 else
                        chart_data.get("view_vegas", {}))

        mode_chart_type = chart_detail.get(
            "chartType", "") or chart_detail.get("selectedChart", "")
        chart_type = self._get_chart_type(chart_data.get("token", ""),
                                          mode_chart_type)
        description = (chart_detail.get("description")
                       or chart_detail.get("chartDescription") or "")
        title = chart_detail.get("title") or chart_detail.get(
            "chartTitle") or ""

        # create datasource urn
        platform, db_name = self._get_platform_and_dbname(
            query.get("data_source_id"))
        source_tables = self._get_source_from_query(query.get("raw_query"))
        datasource_urn = self._get_datasource_urn(platform, db_name,
                                                  source_tables)
        custom_properties = self.construct_chart_custom_properties(
            chart_detail, mode_chart_type)

        # Chart Info
        chart_info = ChartInfoClass(
            type=chart_type,
            description=description,
            title=title,
            lastModified=last_modified,
            chartUrl=f"{self.config.connect_uri}"
            f"{chart_data.get('_links', {}).get('report_viz_web', {}).get('href', '')}",
            inputs=datasource_urn,
            customProperties=custom_properties,
        )
        chart_snapshot.aspects.append(chart_info)

        # Browse Path
        browse_path = BrowsePathsClass(paths=[path])
        chart_snapshot.aspects.append(browse_path)

        # Query
        chart_query = ChartQueryClass(
            rawQuery=query.get("raw_query", ""),
            type=ChartQueryTypeClass.SQL,
        )
        chart_snapshot.aspects.append(chart_query)

        # Ownership
        ownership = self._get_ownership(
            self._get_creator(
                chart_data.get("_links", {}).get("creator",
                                                 {}).get("href", "")))
        if ownership is not None:
            chart_snapshot.aspects.append(ownership)

        return chart_snapshot

    @lru_cache(maxsize=None)
    def _get_reports(self, space_token: str) -> list:
        reports = []
        try:
            reports_json = self._get_request_json(
                f"{self.workspace_uri}/spaces/{space_token}/reports")
            reports = reports_json.get("_embedded", {}).get("reports", {})
        except HTTPError as http_error:
            self.report.report_failure(
                key=f"mode-report-{space_token}",
                reason=
                f"Unable to retrieve reports for space token: {space_token}, "
                f"Reason: {str(http_error)}",
            )
        return reports

    @lru_cache(maxsize=None)
    def _get_queries(self, report_token: str) -> list:
        queries = []
        try:
            queries_json = self._get_request_json(
                f"{self.workspace_uri}/reports/{report_token}/queries")
            queries = queries_json.get("_embedded", {}).get("queries", {})
        except HTTPError as http_error:
            self.report.report_failure(
                key=f"mode-query-{report_token}",
                reason=
                f"Unable to retrieve queries for report token: {report_token}, "
                f"Reason: {str(http_error)}",
            )
        return queries

    @lru_cache(maxsize=None)
    def _get_charts(self, report_token: str, query_token: str) -> list:
        charts = []
        try:
            charts_json = self._get_request_json(
                f"{self.workspace_uri}/reports/{report_token}"
                f"/queries/{query_token}/charts")
            charts = charts_json.get("_embedded", {}).get("charts", {})
        except HTTPError as http_error:
            self.report.report_failure(
                key=f"mode-chart-{report_token}-{query_token}",
                reason=f"Unable to retrieve charts: "
                f"Report token: {report_token} "
                f"Query token: {query_token}, "
                f"Reason: {str(http_error)}",
            )
        return charts

    def _get_request_json(self, url: str) -> Dict:
        r = tenacity.Retrying(
            wait=wait_exponential(
                multiplier=self.config.api_options.retry_backoff_multiplier,
                max=self.config.api_options.max_retry_interval,
            ),
            retry=retry_if_exception_type(HTTPError429),
            stop=stop_after_attempt(self.config.api_options.max_attempts),
        )

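        # Design note: only HTTPError429 is retried, with exponential backoff;
        # all other HTTP errors propagate to the caller immediately.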
        @r.wraps
        def get_request():
            try:
                response = self.session.get(url)
                response.raise_for_status()
                return response.json()
            except HTTPError as http_error:
                error_response = http_error.response
                if error_response.status_code == 429:
                    # respect Retry-After
                    sleep_time = error_response.headers.get("retry-after")
                    if sleep_time is not None:
                        time.sleep(float(sleep_time))
                    raise HTTPError429

                raise http_error

        return get_request()

    def emit_dashboard_mces(self) -> Iterable[MetadataWorkUnit]:
        for space_token, space_name in self.space_tokens.items():
            reports = self._get_reports(space_token)
            for report in reports:
                dashboard_snapshot_from_report = self.construct_dashboard(
                    space_name, report)

                mce = MetadataChangeEvent(
                    proposedSnapshot=dashboard_snapshot_from_report)
                wu = MetadataWorkUnit(id=dashboard_snapshot_from_report.urn,
                                      mce=mce)
                self.report.report_workunit(wu)

                yield wu

    def emit_chart_mces(self) -> Iterable[MetadataWorkUnit]:
        # Space/collection -> report -> query -> Chart
        for space_token, space_name in self.space_tokens.items():
            reports = self._get_reports(space_token)
            for report in reports:
                report_token = report.get("token", "")
                queries = self._get_queries(report_token)
                for query in queries:
                    charts = self._get_charts(report_token,
                                              query.get("token", ""))
                    # build charts
                    for chart in charts:
                        view = chart.get("view") or chart.get("view_vegas")
                        chart_name = view.get("title") or view.get(
                            "chartTitle") or ""
                        path = (f"/mode/{self.config.workspace}/{space_name}"
                                f"/{report.get('name')}/{query.get('name')}/"
                                f"{chart_name}")
                        chart_snapshot = self.construct_chart_from_api_data(
                            chart, query, path)
                        mce = MetadataChangeEvent(
                            proposedSnapshot=chart_snapshot)
                        wu = MetadataWorkUnit(id=chart_snapshot.urn, mce=mce)
                        self.report.report_workunit(wu)

                        yield wu

    @classmethod
    def create(cls, config_dict: dict, ctx: PipelineContext) -> Source:
        config = ModeConfig.parse_obj(config_dict)
        return cls(ctx, config)

    def get_workunits(self) -> Iterable[MetadataWorkUnit]:
        yield from self.emit_dashboard_mces()
        yield from self.emit_chart_mces()

    def get_report(self) -> SourceReport:
        return self.report
Example #3
class LDAPSource(Source):
    config: LDAPSourceConfig
    report: SourceReport

    def __init__(self, ctx: PipelineContext, config: LDAPSourceConfig):
        super().__init__(ctx)
        self.config = config
        self.report = SourceReport()

        ldap.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_ALLOW)
        ldap.set_option(ldap.OPT_REFERRALS, 0)

        self.ldap_client = ldap.initialize(self.config.ldap_server)
        self.ldap_client.protocol_version = 3

        try:
            self.ldap_client.simple_bind_s(
                self.config.ldap_user, self.config.ldap_password
            )
        except ldap.LDAPError as e:
            raise ConfigurationError("LDAP connection failed") from e

        self.lc = create_controls(self.config.page_size)

    @classmethod
    def create(cls, config_dict, ctx):
        config = LDAPSourceConfig.parse_obj(config_dict)
        return cls(ctx, config)

    def get_workunits(self) -> Iterable[MetadataWorkUnit]:
        cookie = True
        while cookie:
            try:
                msgid = self.ldap_client.search_ext(
                    self.config.base_dn,
                    ldap.SCOPE_SUBTREE,
                    self.config.filter,
                    serverctrls=[self.lc],
                )
                rtype, rdata, rmsgid, serverctrls = self.ldap_client.result3(msgid)
            except ldap.LDAPError as e:
                self.report.report_failure(
                    "ldap-control", "LDAP search failed: {}".format(e)
                )
                break

            for dn, attrs in rdata:
                # TODO: create groups if 'organizationalUnit' in attrs['objectClass']

                if (
                    b"inetOrgPerson" in attrs["objectClass"]
                    or b"posixAccount" in attrs["objectClass"]
                ):
                    yield from self.handle_user(dn, attrs)

            pctrls = get_pctrls(serverctrls)
            if not pctrls:
                self.report.report_failure(
                    "ldap-control", "Server ignores RFC 2696 control."
                )
                break

            cookie = set_cookie(self.lc, pctrls, self.config.page_size)

    def handle_user(self, dn, attrs) -> Iterable[MetadataWorkUnit]:
        """
        Handle a DN and attributes by adding manager info and constructing a
        work unit based on the information.
        """
        manager_ldap = None
        if "manager" in attrs:
            try:
                m_cn = attrs["manager"][0].split(b",")[0]
                manager_msgid = self.ldap_client.search_ext(
                    self.config.base_dn,
                    ldap.SCOPE_SUBTREE,
                    f"({m_cn.decode()})",
                    serverctrls=[self.lc],
                )
                m_dn, m_attrs = self.ldap_client.result3(manager_msgid)[1][0]
                manager_ldap = guess_person_ldap(m_dn, m_attrs)
            except ldap.LDAPError as e:
                self.report.report_warning(
                    dn, "manager LDAP search failed: {}".format(e)
                )

        mce = self.build_corp_user_mce(dn, attrs, manager_ldap)
        if mce:
            wu = MetadataWorkUnit(dn, mce)
            self.report.report_workunit(wu)
            yield wu

    def build_corp_user_mce(
        self, dn, attrs, manager_ldap
    ) -> Optional[MetadataChangeEvent]:
        """
        Create the MetadataChangeEvent via DN and attributes.
        """
        ldap_user = guess_person_ldap(dn, attrs)  # avoid shadowing the ldap module
        full_name = attrs["cn"][0].decode()
        first_name = attrs["givenName"][0].decode()
        last_name = attrs["sn"][0].decode()
        email = (attrs["mail"][0]).decode() if "mail" in attrs else None
        display_name = (
            (attrs["displayName"][0]).decode() if "displayName" in attrs else full_name
        )
        department = (
            (attrs["departmentNumber"][0]).decode()
            if "departmentNumber" in attrs
            else None
        )
        title = attrs["title"][0].decode() if "title" in attrs else None
        manager_urn = f"urn:li:corpuser:{manager_ldap}" if manager_ldap else None
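        # e.g. (hypothetical entry): dn "uid=jdoe,ou=people,dc=example,dc=org"
        # with uid "jdoe" yields urn "urn:li:corpuser:jdoe", assuming
        # guess_person_ldap resolves the uid.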

        mce = MetadataChangeEvent(
            proposedSnapshot=CorpUserSnapshotClass(
                urn=f"urn:li:corpuser:{ldap}",
                aspects=[
                    CorpUserInfoClass(
                        active=True,
                        email=email,
                        fullName=full_name,
                        firstName=first_name,
                        lastName=last_name,
                        departmentName=department,
                        displayName=display_name,
                        title=title,
                        managerUrn=manager_urn,
                    )
                ],
            )
        )

        return mce

    def get_report(self):
        return self.report

    def close(self):
        self.ldap_client.unbind()
Example #4
class TableauSource(Source):
    config: TableauConfig
    report: SourceReport
    platform = "tableau"
    server: Optional[Server]
    upstream_tables: Dict[str, Tuple[Any, str]]

    def __hash__(self):
        return id(self)

    def __init__(self, ctx: PipelineContext, config: TableauConfig):
        super().__init__(ctx)

        self.config = config
        self.report = SourceReport()
        self.server = None
        # Per-instance cache mapping table urns to (columns, browse path).
        self.upstream_tables = {}
        # This list keeps track of the datasources actively used by workbooks so that we only
        # retrieve those when emitting published data sources.
        self.datasource_ids_being_used: List[str] = []
        # This list keeps track of the custom SQL tables actively used by workbooks so that we
        # only retrieve those when emitting custom SQL data sources.
        self.custom_sql_ids_being_used: List[str] = []

        self._authenticate()

    def close(self) -> None:
        if self.server is not None:
            self.server.auth.sign_out()

    def _authenticate(self):
        # https://tableau.github.io/server-client-python/docs/api-ref#authentication
        authentication = None
        if self.config.username and self.config.password:
            authentication = TableauAuth(
                username=self.config.username,
                password=self.config.password,
                site_id=self.config.site,
            )
        elif self.config.token_name and self.config.token_value:
            authentication = PersonalAccessTokenAuth(self.config.token_name,
                                                     self.config.token_value,
                                                     self.config.site)
        else:
            raise ConfigurationError(
                "Tableau Source: Either username/password or token_name/token_value must be set"
            )

        try:
            self.server = Server(self.config.connect_uri,
                                 use_server_version=True)
            self.server.auth.sign_in(authentication)
        except ServerResponseError as e:
            logger.error(e)
            self.report.report_failure(
                key="tableau-login",
                reason=f"Unable to login with the credentials provided. "
                f"Reason: {str(e)}",
            )
        except Exception as e:
            logger.error(e)
            self.report.report_failure(key="tableau-login",
                                       reason=f"Unable to Login"
                                       f"Reason: {str(e)}")

    def get_connection_object(
        self,
        query: str,
        connection_type: str,
        query_filter: str,
        count: int = 0,
        current_count: int = 0,
    ) -> Tuple[dict, int, bool]:
        query_data = query_metadata(self.server, query, connection_type, count,
                                    current_count, query_filter)

        if "errors" in query_data:
            self.report.report_warning(
                key="tableau-metadata",
                reason=
                f"Connection: {connection_type} Error: {query_data['errors']}",
            )

        connection_object = (query_data.get("data").get(connection_type, {})
                             if query_data.get("data") else {})

        total_count = connection_object.get("totalCount", 0)
        has_next_page = connection_object.get("pageInfo",
                                              {}).get("hasNextPage", False)
        return connection_object, total_count, has_next_page

    def emit_workbooks(self,
                       workbooks_page_size: int) -> Iterable[MetadataWorkUnit]:

        projects = (f"projectNameWithin: {json.dumps(self.config.projects)}"
                    if self.config.projects else "")
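        # e.g. with self.config.projects = ["Marketing", "Sales"] this yields
        # the filter string 'projectNameWithin: ["Marketing", "Sales"]'.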

        workbook_connection, total_count, has_next_page = self.get_connection_object(
            workbook_graphql_query, "workbooksConnection", projects)

        current_count = 0
        while has_next_page:
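            # Fetch at most workbooks_page_size nodes per request, clamped to
            # the number of workbooks remaining.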
            count = (workbooks_page_size
                     if current_count + workbooks_page_size < total_count
                     else total_count - current_count)
            (
                workbook_connection,
                total_count,
                has_next_page,
            ) = self.get_connection_object(
                workbook_graphql_query,
                "workbooksConnection",
                projects,
                count,
                current_count,
            )

            current_count += count

            for workbook in workbook_connection.get("nodes", []):
                yield from self.emit_workbook_as_container(workbook)
                yield from self.emit_sheets_as_charts(workbook)
                yield from self.emit_dashboards(workbook)
                yield from self.emit_embedded_datasource(workbook)
                yield from self.emit_upstream_tables()

    def _track_custom_sql_ids(self, field: dict) -> None:
        # Tableau shows custom sql datasource as a table in ColumnField.
        if field.get("__typename", "") == "ColumnField":
            for column in field.get("columns", []):
                table_id = column.get("table", {}).get("id")

                if (table_id is not None
                        and table_id not in self.custom_sql_ids_being_used):
                    self.custom_sql_ids_being_used.append(table_id)

    def _create_upstream_table_lineage(
            self,
            datasource: dict,
            project: str,
            is_custom_sql: bool = False) -> List[UpstreamClass]:
        upstream_tables = []

        for table in datasource.get("upstreamTables", []):
            # Skip upstream tables that have no column info (when retrieving embedded
            # datasources) and tables whose name is None.
            # Schema details for these are handled in self.emit_custom_sql_datasources()
            if not is_custom_sql and not table.get("columns"):
                continue
            elif table["name"] is None:
                continue

            upstream_db = table.get("database", {}).get("name", "")
            schema = self._get_schema(table.get("schema", ""), upstream_db)
            table_urn = make_table_urn(
                self.config.env,
                upstream_db,
                table.get("connectionType", ""),
                schema,
                table.get("name", ""),
            )

            upstream_table = UpstreamClass(
                dataset=table_urn,
                type=DatasetLineageTypeClass.TRANSFORMED,
            )
            upstream_tables.append(upstream_table)
            table_path = f"{project.replace('/', REPLACE_SLASH_CHAR)}/{datasource.get('name', '')}/{table.get('name', '')}"
            self.upstream_tables[table_urn] = (
                table.get("columns", []),
                table_path,
            )

        for datasource in datasource.get("upstreamDatasources", []):
            datasource_urn = builder.make_dataset_urn(self.platform,
                                                      datasource["id"],
                                                      self.config.env)
            upstream_table = UpstreamClass(
                dataset=datasource_urn,
                type=DatasetLineageTypeClass.TRANSFORMED,
            )
            upstream_tables.append(upstream_table)

        return upstream_tables

    def emit_custom_sql_datasources(self) -> Iterable[MetadataWorkUnit]:
        count_on_query = len(self.custom_sql_ids_being_used)
        custom_sql_filter = "idWithin: {}".format(
            json.dumps(self.custom_sql_ids_being_used))
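        # Illustrative filter value: 'idWithin: ["<table-id-1>", "<table-id-2>"]'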
        custom_sql_connection, total_count, has_next_page = self.get_connection_object(
            custom_sql_graphql_query, "customSQLTablesConnection",
            custom_sql_filter)

        current_count = 0
        while has_next_page:
            count = min(count_on_query, total_count - current_count)
            (
                custom_sql_connection,
                total_count,
                has_next_page,
            ) = self.get_connection_object(
                custom_sql_graphql_query,
                "customSQLTablesConnection",
                custom_sql_filter,
                count,
                current_count,
            )
            current_count += count

            unique_custom_sql = get_unique_custom_sql(
                custom_sql_connection.get("nodes", []))
            for csql in unique_custom_sql:
                csql_id: str = csql.get("id", "")
                csql_urn = builder.make_dataset_urn(self.platform, csql_id,
                                                    self.config.env)
                dataset_snapshot = DatasetSnapshot(
                    urn=csql_urn,
                    aspects=[],
                )

                # lineage from datasource -> custom sql source #
                yield from self._create_lineage_from_csql_datasource(
                    csql_urn, csql.get("datasources", []))

                # lineage from custom sql -> datasets/tables #
                columns = csql.get("columns", [])
                yield from self._create_lineage_to_upstream_tables(
                    csql_urn, columns)

                #  Schema Metadata
                schema_metadata = self.get_schema_metadata_for_custom_sql(
                    columns)
                if schema_metadata is not None:
                    dataset_snapshot.aspects.append(schema_metadata)

                # Browse path
                browse_paths = BrowsePathsClass(paths=[
                    f"/{self.config.env.lower()}/{self.platform}/Custom SQL/{csql.get('name', '')}/{csql_id}"
                ])
                dataset_snapshot.aspects.append(browse_paths)

                dataset_properties = DatasetPropertiesClass(
                    name=csql.get("name"), description=csql.get("description"))

                dataset_snapshot.aspects.append(dataset_properties)

                view_properties = ViewPropertiesClass(
                    materialized=False,
                    viewLanguage="SQL",
                    viewLogic=clean_query(csql.get("query", "")),
                )
                dataset_snapshot.aspects.append(view_properties)

                yield self.get_metadata_change_event(dataset_snapshot)
                yield self.get_metadata_change_proposal(
                    dataset_snapshot.urn,
                    aspect_name="subTypes",
                    aspect=SubTypesClass(typeNames=["View", "Custom SQL"]),
                )

    def get_schema_metadata_for_custom_sql(
            self, columns: List[dict]) -> Optional[SchemaMetadata]:
        schema_metadata = None
        fields = []
        for field in columns:
            # Datasource fields
            nativeDataType = field.get("remoteType", "UNKNOWN")
            TypeClass = FIELD_TYPE_MAPPING.get(nativeDataType, NullTypeClass)
            schema_field = SchemaField(
                fieldPath=field.get("name", ""),
                type=SchemaFieldDataType(type=TypeClass()),
                nativeDataType=nativeDataType,
                description=field.get("description", ""),
            )
            fields.append(schema_field)

        if fields:
            schema_metadata = SchemaMetadata(
                schemaName="test",
                platform=f"urn:li:dataPlatform:{self.platform}",
                version=0,
                fields=fields,
                hash="",
                platformSchema=OtherSchema(rawSchema=""),
            )
        return schema_metadata

    def _create_lineage_from_csql_datasource(
            self, csql_urn: str,
            csql_datasource: List[dict]) -> Iterable[MetadataWorkUnit]:
        for datasource in csql_datasource:
            datasource_urn = builder.make_dataset_urn(self.platform,
                                                      datasource.get("id", ""),
                                                      self.config.env)
            upstream_csql = UpstreamClass(
                dataset=csql_urn,
                type=DatasetLineageTypeClass.TRANSFORMED,
            )

            upstream_lineage = UpstreamLineage(upstreams=[upstream_csql])
            yield self.get_metadata_change_proposal(
                datasource_urn,
                aspect_name="upstreamLineage",
                aspect=upstream_lineage)

    def _create_lineage_to_upstream_tables(
            self, csql_urn: str,
            columns: List[dict]) -> Iterable[MetadataWorkUnit]:
        used_datasources = []
        # Get data sources from columns' reference fields.
        for field in columns:
            data_sources = [
                reference.get("datasource")
                for reference in field.get("referencedByFields", [])
                if reference.get("datasource") is not None
            ]

            for datasource in data_sources:
                if datasource.get("id", "") in used_datasources:
                    continue
                used_datasources.append(datasource.get("id", ""))
                upstream_tables = self._create_upstream_table_lineage(
                    datasource,
                    datasource.get("workbook", {}).get("projectName", ""),
                    is_custom_sql=True,
                )
                if upstream_tables:
                    upstream_lineage = UpstreamLineage(
                        upstreams=upstream_tables)
                    yield self.get_metadata_change_proposal(
                        csql_urn,
                        aspect_name="upstreamLineage",
                        aspect=upstream_lineage,
                    )

    def _get_schema_metadata_for_embedded_datasource(
            self, datasource_fields: List[dict]) -> Optional[SchemaMetadata]:
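        # Translate Tableau field metadata into DataHub SchemaFields; unmapped
        # dataTypes fall back to NullTypeClass.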
        fields = []
        schema_metadata = None
        for field in datasource_fields:
            # check datasource - custom sql relations from a field being referenced
            self._track_custom_sql_ids(field)

            nativeDataType = field.get("dataType", "UNKNOWN")
            TypeClass = FIELD_TYPE_MAPPING.get(nativeDataType, NullTypeClass)

            schema_field = SchemaField(
                fieldPath=field["name"],
                type=SchemaFieldDataType(type=TypeClass()),
                description=make_description_from_params(
                    field.get("description", ""), field.get("formula")),
                nativeDataType=nativeDataType,
                globalTags=get_tags_from_params([
                    field.get("role", ""),
                    field.get("__typename", ""),
                    field.get("aggregation", ""),
                ]) if self.config.ingest_tags else None,
            )
            fields.append(schema_field)

        if fields:
            schema_metadata = SchemaMetadata(
                schemaName="test",
                platform=f"urn:li:dataPlatform:{self.platform}",
                version=0,
                fields=fields,
                hash="",
                platformSchema=OtherSchema(rawSchema=""),
            )

        return schema_metadata

    def get_metadata_change_event(
        self, snap_shot: Union["DatasetSnapshot", "DashboardSnapshot",
                               "ChartSnapshot"]
    ) -> MetadataWorkUnit:
        mce = MetadataChangeEvent(proposedSnapshot=snap_shot)
        work_unit = MetadataWorkUnit(id=snap_shot.urn, mce=mce)
        self.report.report_workunit(work_unit)
        return work_unit

    def get_metadata_change_proposal(
        self,
        urn: str,
        aspect_name: str,
        aspect: Union["UpstreamLineage", "SubTypesClass"],
    ) -> MetadataWorkUnit:
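        # Unlike an MCE, which carries a full entity snapshot, an MCP upserts a
        # single aspect (here: upstreamLineage or subTypes) on the entity.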
        mcp = MetadataChangeProposalWrapper(
            entityType="dataset",
            changeType=ChangeTypeClass.UPSERT,
            entityUrn=urn,
            aspectName=aspect_name,
            aspect=aspect,
        )
        mcp_workunit = MetadataWorkUnit(
            id=f"tableau-{mcp.entityUrn}-{mcp.aspectName}",
            mcp=mcp,
            treat_errors_as_warnings=True,
        )
        self.report.report_workunit(mcp_workunit)
        return mcp_workunit

    def emit_datasource(
            self,
            datasource: dict,
            workbook: Optional[dict] = None) -> Iterable[MetadataWorkUnit]:
        datasource_info = workbook
        if workbook is None:
            datasource_info = datasource

        project = (datasource_info.get("projectName", "").replace(
            "/", REPLACE_SLASH_CHAR) if datasource_info else "")
        datasource_id = datasource.get("id", "")
        datasource_name = f"{datasource.get('name')}.{datasource_id}"
        datasource_urn = builder.make_dataset_urn(self.platform, datasource_id,
                                                  self.config.env)
        if datasource_id not in self.datasource_ids_being_used:
            self.datasource_ids_being_used.append(datasource_id)

        dataset_snapshot = DatasetSnapshot(
            urn=datasource_urn,
            aspects=[],
        )

        # Browse path
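        # Illustrative path: "/prod/tableau/<project>/<datasource name>/<datasource name>.<id>"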
        browse_paths = BrowsePathsClass(paths=[
            f"/{self.config.env.lower()}/{self.platform}/{project}/{datasource.get('name', '')}/{datasource_name}"
        ])
        dataset_snapshot.aspects.append(browse_paths)

        # Ownership
        owner = (self._get_ownership(
            datasource_info.get("owner", {}).get("username", ""))
                 if datasource_info else None)
        if owner is not None:
            dataset_snapshot.aspects.append(owner)

        # Dataset properties
        dataset_props = DatasetPropertiesClass(
            name=datasource.get("name"),
            description=datasource.get("description"),
            customProperties={
                "hasExtracts":
                str(datasource.get("hasExtracts", "")),
                "extractLastRefreshTime":
                datasource.get("extractLastRefreshTime", "") or "",
                "extractLastIncrementalUpdateTime":
                datasource.get("extractLastIncrementalUpdateTime", "") or "",
                "extractLastUpdateTime":
                datasource.get("extractLastUpdateTime", "") or "",
                "type":
                datasource.get("__typename", ""),
            },
        )
        dataset_snapshot.aspects.append(dataset_props)

        # Upstream Tables
        if datasource.get("upstreamTables") is not None:
            # datasource -> db table relations
            upstream_tables = self._create_upstream_table_lineage(
                datasource, project)

            if upstream_tables:
                upstream_lineage = UpstreamLineage(upstreams=upstream_tables)
                yield self.get_metadata_change_proposal(
                    datasource_urn,
                    aspect_name="upstreamLineage",
                    aspect=upstream_lineage,
                )

        # Datasource Fields
        schema_metadata = self._get_schema_metadata_for_embedded_datasource(
            datasource.get("fields", []))
        if schema_metadata is not None:
            dataset_snapshot.aspects.append(schema_metadata)

        yield self.get_metadata_change_event(dataset_snapshot)
        yield self.get_metadata_change_proposal(
            dataset_snapshot.urn,
            aspect_name="subTypes",
            aspect=SubTypesClass(typeNames=["Data Source"]),
        )

        if datasource.get("__typename") == "EmbeddedDatasource":
            yield from add_entity_to_container(self.gen_workbook_key(workbook),
                                               "dataset", dataset_snapshot.urn)

    def emit_published_datasources(self) -> Iterable[MetadataWorkUnit]:
        count_on_query = len(self.datasource_ids_being_used)
        datasource_filter = "idWithin: {}".format(
            json.dumps(self.datasource_ids_being_used))
        (
            published_datasource_conn,
            total_count,
            has_next_page,
        ) = self.get_connection_object(
            published_datasource_graphql_query,
            "publishedDatasourcesConnection",
            datasource_filter,
        )

        current_count = 0
        while has_next_page:
            count = min(count_on_query, total_count - current_count)
            (
                published_datasource_conn,
                total_count,
                has_next_page,
            ) = self.get_connection_object(
                published_datasource_graphql_query,
                "publishedDatasourcesConnection",
                datasource_filter,
                count,
                current_count,
            )

            current_count += count
            for datasource in published_datasource_conn.get("nodes", []):
                yield from self.emit_datasource(datasource)

    def emit_upstream_tables(self) -> Iterable[MetadataWorkUnit]:
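        # self.upstream_tables is populated in _create_upstream_table_lineage()
        # and maps table URN -> (column metadata, browse path).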
        for (table_urn, (columns, path)) in self.upstream_tables.items():
            dataset_snapshot = DatasetSnapshot(
                urn=table_urn,
                aspects=[],
            )
            # Browse path
            browse_paths = BrowsePathsClass(
                paths=[f"/{self.config.env.lower()}/{self.platform}/{path}"])
            dataset_snapshot.aspects.append(browse_paths)

            fields = []
            for field in columns:
                nativeDataType = field.get("remoteType", "UNKNOWN")
                TypeClass = FIELD_TYPE_MAPPING.get(nativeDataType,
                                                   NullTypeClass)

                schema_field = SchemaField(
                    fieldPath=field["name"],
                    type=SchemaFieldDataType(type=TypeClass()),
                    description="",
                    nativeDataType=nativeDataType,
                )

                fields.append(schema_field)

            schema_metadata = SchemaMetadata(
                schemaName="test",
                platform=f"urn:li:dataPlatform:{self.platform}",
                version=0,
                fields=fields,
                hash="",
                platformSchema=OtherSchema(rawSchema=""),
            )
            dataset_snapshot.aspects.append(schema_metadata)

            yield self.get_metadata_change_event(dataset_snapshot)

    # Older Tableau versions do not support fetching a sheet's upstreamDatasources;
    # this achieves the same effect by walking each datasource's downstreamSheets.
    def get_sheetwise_upstream_datasources(self, workbook: dict) -> dict:
        sheet_upstream_datasources: dict = {}
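        # Resulting shape (illustrative): {"<sheet-id>": {"<datasource-id>", ...}}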

        for embedded_ds in workbook.get("embeddedDatasources", []):
            for sheet in embedded_ds.get("downstreamSheets", []):
                if sheet.get("id") not in sheet_upstream_datasources:
                    sheet_upstream_datasources[sheet.get("id")] = set()
                sheet_upstream_datasources[sheet.get("id")].add(
                    embedded_ds.get("id"))

        for published_ds in workbook.get("upstreamDatasources", []):
            for sheet in published_ds.get("downstreamSheets", []):
                if sheet.get("id") not in sheet_upstream_datasources:
                    sheet_upstream_datasources[sheet.get("id")] = set()
                sheet_upstream_datasources[sheet.get("id")].add(
                    published_ds.get("id"))
        return sheet_upstream_datasources

    def emit_sheets_as_charts(self,
                              workbook: Dict) -> Iterable[MetadataWorkUnit]:
        sheet_upstream_datasources = self.get_sheetwise_upstream_datasources(
            workbook)
        for sheet in workbook.get("sheets", []):
            chart_snapshot = ChartSnapshot(
                urn=builder.make_chart_urn(self.platform, sheet.get("id")),
                aspects=[],
            )

            creator = workbook.get("owner", {}).get("username", "")
            created_at = sheet.get("createdAt", datetime.now().isoformat())
            updated_at = sheet.get("updatedAt", datetime.now().isoformat())
            last_modified = self.get_last_modified(creator, created_at,
                                                   updated_at)

            if sheet.get("path"):
                site_part = f"/site/{self.config.site}" if self.config.site else ""
                sheet_external_url = (
                    f"{self.config.connect_uri}/#{site_part}/views/{sheet.get('path')}"
                )
            elif sheet.get("containedInDashboards"):
                # sheet contained in dashboard
                site_part = f"/t/{self.config.site}" if self.config.site else ""
                dashboard_path = sheet.get("containedInDashboards")[0].get(
                    "path", "")
                sheet_external_url = f"{self.config.connect_uri}{site_part}/authoring/{dashboard_path}/{sheet.get('name', '')}"
            else:
                # hidden or viz-in-tooltip sheet
                sheet_external_url = None
            fields = {}
            for field in sheet.get("datasourceFields", []):
                description = make_description_from_params(
                    get_field_value_in_sheet(field, "description"),
                    get_field_value_in_sheet(field, "formula"),
                )
                fields[get_field_value_in_sheet(field, "name")] = description

            # datasource urns
            datasource_urns = []
            data_sources = sheet_upstream_datasources.get(
                sheet.get("id"), set())

            for ds_id in data_sources:
                if not ds_id:
                    continue
                ds_urn = builder.make_dataset_urn(self.platform, ds_id,
                                                  self.config.env)
                datasource_urns.append(ds_urn)
                if ds_id not in self.datasource_ids_being_used:
                    self.datasource_ids_being_used.append(ds_id)

            # Chart Info
            chart_info = ChartInfoClass(
                description="",
                title=sheet.get("name", ""),
                lastModified=last_modified,
                externalUrl=sheet_external_url,
                inputs=sorted(datasource_urns),
                customProperties=fields,
            )
            chart_snapshot.aspects.append(chart_info)

            # Browse path
            browse_path = BrowsePathsClass(paths=[
                f"/{self.platform}/{workbook.get('projectName', '').replace('/', REPLACE_SLASH_CHAR)}"
                f"/{workbook.get('name', '')}"
                f"/{sheet.get('name', '').replace('/', REPLACE_SLASH_CHAR)}"
            ])
            chart_snapshot.aspects.append(browse_path)

            # Ownership
            owner = self._get_ownership(creator)
            if owner is not None:
                chart_snapshot.aspects.append(owner)

            #  Tags
            tag_list = sheet.get("tags", [])
            if tag_list and self.config.ingest_tags:
                tag_list_str = [
                    t.get("name", "").upper() for t in tag_list
                    if t is not None
                ]
                chart_snapshot.aspects.append(
                    builder.make_global_tag_aspect_with_tag_list(tag_list_str))

            yield self.get_metadata_change_event(chart_snapshot)

            yield from add_entity_to_container(self.gen_workbook_key(workbook),
                                               "chart", chart_snapshot.urn)

    def emit_workbook_as_container(
            self, workbook: Dict) -> Iterable[MetadataWorkUnit]:
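        # A workbook becomes a DataHub container; its charts, dashboards and
        # embedded datasources are attached to it via add_entity_to_container().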

        workbook_container_key = self.gen_workbook_key(workbook)
        creator = workbook.get("owner", {}).get("username")

        owner_urn = (builder.make_user_urn(creator) if
                     (creator and self.config.ingest_owner) else None)

        site_part = f"/site/{self.config.site}" if self.config.site else ""
        workbook_uri = workbook.get("uri", "")
        workbook_part = (workbook_uri[workbook_uri.index("/workbooks/"):]
                         if "/workbooks/" in workbook_uri else None)
        workbook_external_url = (
            f"{self.config.connect_uri}/#{site_part}{workbook_part}"
            if workbook_part else None)

        tag_list = workbook.get("tags", [])
        tag_list_str = (
            [t.get("name", "").upper() for t in tag_list if t is not None] if
            (tag_list and self.config.ingest_tags) else None)

        container_workunits = gen_containers(
            container_key=workbook_container_key,
            name=workbook.get("name", ""),
            sub_types=["Workbook"],
            description=workbook.get("description"),
            owner_urn=owner_urn,
            external_url=workbook_external_url,
            tags=tag_list_str,
        )

        for wu in container_workunits:
            self.report.report_workunit(wu)
            yield wu

    def gen_workbook_key(self, workbook: Dict) -> WorkbookKey:
        return WorkbookKey(platform=self.platform,
                           instance=None,
                           workbook_id=workbook["id"])

    def emit_dashboards(self, workbook: Dict) -> Iterable[MetadataWorkUnit]:
        for dashboard in workbook.get("dashboards", []):
            dashboard_snapshot = DashboardSnapshot(
                urn=builder.make_dashboard_urn(self.platform,
                                               dashboard.get("id", "")),
                aspects=[],
            )

            creator = workbook.get("owner", {}).get("username", "")
            created_at = dashboard.get("createdAt", datetime.now().isoformat())
            updated_at = dashboard.get("updatedAt", datetime.now().isoformat())
            last_modified = self.get_last_modified(creator, created_at,
                                                   updated_at)

            site_part = f"/site/{self.config.site}" if self.config.site else ""
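            # Illustrative URL:
            # https://tableau.example.com/#/site/<site>/views/<dashboard-path>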
            dashboard_external_url = f"{self.config.connect_uri}/#{site_part}/views/{dashboard.get('path', '')}"
            title = (dashboard.get("name") or "").replace(
                "/", REPLACE_SLASH_CHAR)
            chart_urns = [
                builder.make_chart_urn(self.platform, sheet.get("id"))
                for sheet in dashboard.get("sheets", [])
            ]
            dashboard_info_class = DashboardInfoClass(
                description="",
                title=title,
                charts=chart_urns,
                lastModified=last_modified,
                dashboardUrl=dashboard_external_url,
                customProperties={},
            )
            dashboard_snapshot.aspects.append(dashboard_info_class)

            # browse path
            browse_paths = BrowsePathsClass(paths=[
                f"/{self.platform}/{workbook.get('projectName', '').replace('/', REPLACE_SLASH_CHAR)}"
                f"/{workbook.get('name', '').replace('/', REPLACE_SLASH_CHAR)}"
                f"/{title}"
            ])
            dashboard_snapshot.aspects.append(browse_paths)

            # Ownership
            owner = self._get_ownership(creator)
            if owner is not None:
                dashboard_snapshot.aspects.append(owner)

            yield self.get_metadata_change_event(dashboard_snapshot)

            yield from add_entity_to_container(self.gen_workbook_key(workbook),
                                               "dashboard",
                                               dashboard_snapshot.urn)

    def emit_embedded_datasource(self,
                                 workbook: Dict) -> Iterable[MetadataWorkUnit]:
        for datasource in workbook.get("embeddedDatasources", []):
            yield from self.emit_datasource(datasource, workbook)

    @lru_cache(maxsize=None)
    def _get_schema(self, schema_provided: str, database: str) -> str:
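        # Fall back to the configured default schema for this database when
        # Tableau did not report one, e.g. (illustrative config):
        #   default_schema_map: {"mydb": "public"}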
        schema = schema_provided
        if not schema_provided and database in self.config.default_schema_map:
            schema = self.config.default_schema_map[database]

        return schema

    @lru_cache(maxsize=None)
    def get_last_modified(self, creator: str, created_at: str,
                          updated_at: str) -> ChangeAuditStamps:
        last_modified = ChangeAuditStamps()
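        # createdAt/updatedAt arrive as timestamp strings; dateutil parses them
        # and they are converted to epoch milliseconds.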
        if creator:
            modified_actor = builder.make_user_urn(creator)
            created_ts = int(dp.parse(created_at).timestamp() * 1000)
            modified_ts = int(dp.parse(updated_at).timestamp() * 1000)
            last_modified = ChangeAuditStamps(
                created=AuditStamp(time=created_ts, actor=modified_actor),
                lastModified=AuditStamp(time=modified_ts,
                                        actor=modified_actor),
            )
        return last_modified

    @lru_cache(maxsize=None)
    def _get_ownership(self, user: str) -> Optional[OwnershipClass]:
        if self.config.ingest_owner and user:
            owner_urn = builder.make_user_urn(user)
            ownership: OwnershipClass = OwnershipClass(owners=[
                OwnerClass(
                    owner=owner_urn,
                    type=OwnershipTypeClass.DATAOWNER,
                )
            ])
            return ownership

        return None

    @classmethod
    def create(cls, config_dict: dict, ctx: PipelineContext) -> Source:
        config = TableauConfig.parse_obj(config_dict)
        return cls(ctx, config)

    def get_workunits(self) -> Iterable[MetadataWorkUnit]:
        if self.server is None or not self.server.is_signed_in():
            return
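        # Emit workbooks first: doing so populates datasource_ids_being_used and
        # custom_sql_ids_being_used, which the next two emitters filter on.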
        try:
            yield from self.emit_workbooks(self.config.workbooks_page_size)
            if self.datasource_ids_being_used:
                yield from self.emit_published_datasources()
            if self.custom_sql_ids_being_used:
                yield from self.emit_custom_sql_datasources()
        except MetadataQueryException as md_exception:
            self.report.report_failure(
                key="tableau-metadata",
                reason=f"Unable to retrieve metadata from Tableau. Information: {md_exception}",
            )

    def get_report(self) -> SourceReport:
        return self.report