Example #1
def get_column_type(
    report: SourceReport, dataset_name: str, column_type: str
) -> SchemaFieldDataType:
    """
    Maps known DBT types to datahub types
    """
    column_type_stripped = ""

    pattern = re.compile(r"[\w ]+")  # drop all non alphanumerics
    match = pattern.match(column_type)
    if match is not None:
        column_type_stripped = match.group()

    TypeClass: Any = _field_type_mapping.get(column_type_stripped)

    if TypeClass is None:
        report.report_warning(
            dataset_name, f"unable to map type {column_type} to metadata schema"
        )
        TypeClass = NullTypeClass

    return SchemaFieldDataType(type=TypeClass())
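
A standalone sketch of the stripping step above: the regex keeps the leading run of word characters and spaces, so a trailing precision/scale modifier such as "(38,0)" is dropped before the mapping lookup.

import re

pattern = re.compile(r"[\w ]+")

for raw in ["varchar(256)", "numeric(38,0)", "double precision", "array<string>"]:
    match = pattern.match(raw)
    stripped = match.group() if match is not None else ""
    print(f"{raw!r} -> {stripped!r}")
# 'varchar(256)' -> 'varchar'
# 'numeric(38,0)' -> 'numeric'
# 'double precision' -> 'double precision'
# 'array<string>' -> 'array'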
Example #2
 def _get_tags_from_field_type(
         field_type: ViewFieldType,
         reporter: SourceReport) -> Optional[GlobalTagsClass]:
     if field_type in LookerUtil.type_to_tag_map:
         return GlobalTagsClass(tags=[
             TagAssociationClass(tag=tag_name)
             for tag_name in LookerUtil.type_to_tag_map[field_type]
         ])
     else:
         reporter.report_warning(
             "lookml",
             f"Failed to map view field type {field_type}. Won't emit tags for it",
         )
         return None
Example #3
    def _extract_metadata_from_sql_query(
        cls: Type,
        reporter: SourceReport,
        parse_table_names_from_sql: bool,
        sql_parser_path: str,
        view_name: str,
        sql_table_name: Optional[str],
        derived_table: dict,
        fields: List[ViewField],
    ) -> Tuple[List[ViewField], List[str]]:
        sql_table_names: List[str] = []
        if parse_table_names_from_sql and "sql" in derived_table:
            logger.debug(
                f"Parsing sql from derived table section of view: {view_name}")
            sql_query = derived_table["sql"]

            # Skip queries that contain liquid variables. We currently don't parse them correctly
            if "{%" in sql_query:
                logger.debug(
                    f"{view_name}: Skipping sql_query parsing since it contains liquid variables"
                )
                return fields, sql_table_names
            # Looker supports sql fragments that omit the SELECT and FROM parts of the query
            # Add those in if we detect that it is missing
            if not re.search(r"SELECT\s", sql_query, flags=re.I):
                # add a SELECT clause at the beginning
                sql_query = "SELECT " + sql_query
            if not re.search(r"FROM\s", sql_query, flags=re.I):
                # add a FROM clause at the end
                sql_query = f"{sql_query} FROM {sql_table_name if sql_table_name is not None else view_name}"
            # Get the list of tables in the query
            try:
                sql_info = cls._get_sql_info(sql_query, sql_parser_path)
                sql_table_names = sql_info.table_names
                column_names = sql_info.column_names
                if fields == []:
                    # it seems like the view is defined purely as sql, let's try using the column names to populate the schema
                    fields = [
                        # set types to unknown for now as our sql parser doesn't give us column types yet
                        ViewField(c, "unknown", "", ViewFieldType.UNKNOWN)
                        for c in column_names
                    ]
            except Exception as e:
                reporter.report_warning(
                    f"looker-view-{view_name}",
                    f"Failed to parse sql query, lineage will not be accurate. Exception: {e}",
                )

        return fields, sql_table_names
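
A standalone sketch of the fragment completion above: Looker derived tables may store only "col_a, col_b" as their sql, so the missing SELECT and FROM clauses are added before the query is handed to the parser.

import re

sql_query = "user_id, COUNT(*) AS order_count"
view_name = "orders_rollup"  # used as the fallback table name
sql_table_name = None

if not re.search(r"SELECT\s", sql_query, flags=re.I):
    sql_query = "SELECT " + sql_query
if not re.search(r"FROM\s", sql_query, flags=re.I):
    sql_query = f"{sql_query} FROM {sql_table_name if sql_table_name is not None else view_name}"

print(sql_query)  # SELECT user_id, COUNT(*) AS order_count FROM orders_rollup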
Example #4
def get_column_type(report: SourceReport, dataset_name: str,
                    column_type: str) -> SchemaFieldDataType:
    """
    Maps known DBT types to datahub types
    """
    TypeClass: Any = _field_type_mapping.get(column_type)

    if TypeClass is None:
        # attempt Postgres modified type
        TypeClass = resolve_postgres_modified_type(column_type)

    # if still not found, report the warning
    if TypeClass is None:
        report.report_warning(
            dataset_name,
            f"unable to map type {column_type} to metadata schema")
        TypeClass = NullTypeClass

    return SchemaFieldDataType(type=TypeClass())
Example #5
    def _get_field_type(native_type: str,
                        reporter: SourceReport) -> SchemaFieldDataType:

        type_class = LookerUtil.field_type_mapping.get(native_type)

        if type_class is None:
            # attempt Postgres modified type
            type_class = resolve_postgres_modified_type(native_type)

        # if still not found, report a warning
        if type_class is None:
            reporter.report_warning(
                native_type,
                f"The type '{native_type}' is not recognized for field type, setting as NullTypeClass.",
            )
            type_class = NullTypeClass

        data_type = SchemaFieldDataType(type=type_class())
        return data_type
Example #6
def get_column_type(report: SourceReport, dataset_name: str,
                    column_type: Any) -> SchemaFieldDataType:
    """
    Maps known Spark types to datahub types
    """
    TypeClass: Any = None

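    # the mapping is keyed by Spark DataType classes, so column_type is
    # expected to be a Spark type instance rather than the type's string name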
    for field_type, type_class in _field_type_mapping.items():
        if isinstance(column_type, field_type):
            TypeClass = type_class
            break

    # if still not found, report the warning
    if TypeClass is None:
        report.report_warning(
            dataset_name,
            f"unable to map type {column_type} to metadata schema")
        TypeClass = NullTypeClass

    return SchemaFieldDataType(type=TypeClass())
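
A minimal usage sketch for the variant above, assuming pyspark is installed and that _field_type_mapping is keyed by Spark DataType classes (the mapping values below are hypothetical stand-ins for the datahub type classes):

from pyspark.sql.types import LongType, StringType

_field_type_mapping = {StringType: "StringTypeClass", LongType: "NumberTypeClass"}

column_type = StringType()  # a Spark DataType instance, not the string "string"
TypeClass = None
for field_type, type_class in _field_type_mapping.items():
    if isinstance(column_type, field_type):
        TypeClass = type_class
        break
print(TypeClass)  # StringTypeClass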
Example #7
class MetabaseSource(Source):
    """
    This plugin extracts charts, dashboards, and associated metadata. This plugin is in beta and has only been tested
    on PostgreSQL and H2 databases.
    ### Dashboard

    The [/api/dashboard](https://www.metabase.com/docs/latest/api-documentation.html#dashboard) endpoint is used to
    retrieve the following dashboard information.

    - Title and description
    - Last edited by
    - Owner
    - Link to the dashboard in Metabase
    - Associated charts

    ### Chart

    The [/api/card](https://www.metabase.com/docs/latest/api-documentation.html#card) endpoint is used to
    retrieve the following information.

    - Title and description
    - Last edited by
    - Owner
    - Link to the chart in Metabase
    - Datasource and lineage

    The following properties for a chart are ingested in DataHub.

    | Name          | Description                                     |
    | ------------- | ----------------------------------------------- |
    | `Dimensions`  | Column names                                    |
    | `Filters`     | Any filters applied to the chart                |
    | `Metrics`     | All columns that are being used for aggregation |


    """

    config: MetabaseConfig
    report: SourceReport
    platform = "metabase"

    def __hash__(self):
        return id(self)

    def __init__(self, ctx: PipelineContext, config: MetabaseConfig):
        super().__init__(ctx)
        self.config = config
        self.report = SourceReport()

        login_response = requests.post(
            f"{self.config.connect_uri}/api/session",
            json={
                "username": self.config.username,
                "password": self.config.password,
            },
        )

        login_response.raise_for_status()
        self.access_token = login_response.json().get("id", "")

        self.session = requests.session()
        self.session.headers.update({
            "X-Metabase-Session": f"{self.access_token}",
            "Content-Type": "application/json",
            "Accept": "*/*",
        })

        # Test the connection
        try:
            test_response = self.session.get(
                f"{self.config.connect_uri}/api/user/current")
            test_response.raise_for_status()
        except HTTPError as e:
            self.report.report_failure(
                key="metabase-session",
                reason=f"Unable to retrieve user {self.config.username} information. {e}",
            )

    def close(self) -> None:
        response = requests.delete(
            f"{self.config.connect_uri}/api/session",
            headers={"X-Metabase-Session": self.access_token},
        )
        if response.status_code not in (200, 204):
            self.report.report_failure(
                key="metabase-session",
                reason=f"Unable to logout for user {self.config.username}",
            )

    def emit_dashboard_mces(self) -> Iterable[MetadataWorkUnit]:
        try:
            dashboard_response = self.session.get(
                f"{self.config.connect_uri}/api/dashboard")
            dashboard_response.raise_for_status()
            dashboards = dashboard_response.json()

            for dashboard_info in dashboards:
                dashboard_snapshot = self.construct_dashboard_from_api_data(
                    dashboard_info)
                if dashboard_snapshot is not None:
                    mce = MetadataChangeEvent(
                        proposedSnapshot=dashboard_snapshot)
                    wu = MetadataWorkUnit(id=dashboard_snapshot.urn, mce=mce)
                    self.report.report_workunit(wu)
                    yield wu

        except HTTPError as http_error:
            self.report.report_failure(
                key="metabase-dashboard",
                reason=f"Unable to retrieve dashboards. "
                f"Reason: {str(http_error)}",
            )

    @staticmethod
    def get_timestamp_millis_from_ts_string(ts_str: str) -> int:
        """
        Converts the given timestamp string to milliseconds. If parsing fails,
        returns the utc-now in milliseconds.
        """
        try:
            return int(dp.parse(ts_str).timestamp() * 1000)
        except (dp.ParserError, OverflowError):
            return int(datetime.utcnow().timestamp() * 1000)
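
A quick standalone sketch of the helper above, assuming dp is dateutil.parser (dateutil >= 2.8.1, where ParserError is defined):

import dateutil.parser as dp

print(int(dp.parse("2021-06-01T12:00:00+00:00").timestamp() * 1000))
# 1622548800000; unparseable strings such as "None" raise dp.ParserError,
# which the method above catches, falling back to the current UTC time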

    def construct_dashboard_from_api_data(
            self, dashboard_info: dict) -> Optional[DashboardSnapshot]:

        dashboard_id = dashboard_info.get("id", "")
        dashboard_url = f"{self.config.connect_uri}/api/dashboard/{dashboard_id}"
        try:
            dashboard_response = self.session.get(dashboard_url)
            dashboard_response.raise_for_status()
            dashboard_details = dashboard_response.json()
        except HTTPError as http_error:
            self.report.report_failure(
                key=f"metabase-dashboard-{dashboard_id}",
                reason=f"Unable to retrieve dashboard. "
                f"Reason: {str(http_error)}",
            )
            return None

        dashboard_urn = builder.make_dashboard_urn(
            self.platform, dashboard_details.get("id", ""))
        dashboard_snapshot = DashboardSnapshot(
            urn=dashboard_urn,
            aspects=[],
        )
        last_edit_by = dashboard_details.get("last-edit-info") or {}
        modified_actor = builder.make_user_urn(
            last_edit_by.get("email", "unknown"))
        modified_ts = self.get_timestamp_millis_from_ts_string(
            f"{last_edit_by.get('timestamp')}")
        title = dashboard_details.get("name", "") or ""
        description = dashboard_details.get("description", "") or ""
        last_modified = ChangeAuditStamps(
            created=AuditStamp(time=modified_ts, actor=modified_actor),
            lastModified=AuditStamp(time=modified_ts, actor=modified_actor),
        )

        chart_urns = []
        cards_data = dashboard_details.get("ordered_cards", "{}")
        for card_info in cards_data:
            chart_urn = builder.make_chart_urn(self.platform,
                                               card_info.get("id", ""))
            chart_urns.append(chart_urn)

        dashboard_info_class = DashboardInfoClass(
            description=description,
            title=title,
            charts=chart_urns,
            lastModified=last_modified,
            dashboardUrl=f"{self.config.connect_uri}/dashboard/{dashboard_id}",
            customProperties={},
        )
        dashboard_snapshot.aspects.append(dashboard_info_class)

        # Ownership
        ownership = self._get_ownership(dashboard_details.get(
            "creator_id", ""))
        if ownership is not None:
            dashboard_snapshot.aspects.append(ownership)

        return dashboard_snapshot

    @lru_cache(maxsize=None)
    def _get_ownership(self, creator_id: int) -> Optional[OwnershipClass]:
        user_info_url = f"{self.config.connect_uri}/api/user/{creator_id}"
        try:
            user_info_response = self.session.get(user_info_url)
            user_info_response.raise_for_status()
            user_details = user_info_response.json()
        except HTTPError as http_error:
            self.report.report_failure(
                key=f"metabase-user-{creator_id}",
                reason=f"Unable to retrieve User info. "
                f"Reason: {str(http_error)}",
            )
            return None

        owner_urn = builder.make_user_urn(user_details.get("email", ""))
        if owner_urn is not None:
            ownership: OwnershipClass = OwnershipClass(owners=[
                OwnerClass(
                    owner=owner_urn,
                    type=OwnershipTypeClass.DATAOWNER,
                )
            ])
            return ownership

        return None

    def emit_card_mces(self) -> Iterable[MetadataWorkUnit]:
        try:
            card_response = self.session.get(
                f"{self.config.connect_uri}/api/card")
            card_response.raise_for_status()
            cards = card_response.json()

            for card_info in cards:
                chart_snapshot = self.construct_card_from_api_data(card_info)
                if chart_snapshot is not None:
                    mce = MetadataChangeEvent(proposedSnapshot=chart_snapshot)
                    wu = MetadataWorkUnit(id=chart_snapshot.urn, mce=mce)
                    self.report.report_workunit(wu)
                    yield wu

        except HTTPError as http_error:
            self.report.report_failure(
                key="metabase-cards",
                reason=f"Unable to retrieve cards. "
                f"Reason: {str(http_error)}",
            )

    def construct_card_from_api_data(
            self, card_data: dict) -> Optional[ChartSnapshot]:
        card_id = card_data.get("id", "")
        card_url = f"{self.config.connect_uri}/api/card/{card_id}"
        try:
            card_response = self.session.get(card_url)
            card_response.raise_for_status()
            card_details = card_response.json()
        except HTTPError as http_error:
            self.report.report_failure(
                key=f"metabase-card-{card_id}",
                reason=f"Unable to retrieve Card info. "
                f"Reason: {str(http_error)}",
            )
            return None

        chart_urn = builder.make_chart_urn(self.platform, card_id)
        chart_snapshot = ChartSnapshot(
            urn=chart_urn,
            aspects=[],
        )

        last_edit_by = card_details.get("last-edit-info") or {}
        modified_actor = builder.make_user_urn(
            last_edit_by.get("email", "unknown"))
        modified_ts = self.get_timestamp_millis_from_ts_string(
            f"{last_edit_by.get('timestamp')}")
        last_modified = ChangeAuditStamps(
            created=AuditStamp(time=modified_ts, actor=modified_actor),
            lastModified=AuditStamp(time=modified_ts, actor=modified_actor),
        )

        chart_type = self._get_chart_type(card_details.get("id", ""),
                                          card_details.get("display"))
        description = card_details.get("description") or ""
        title = card_details.get("name") or ""
        datasource_urn = self.get_datasource_urn(card_details)
        custom_properties = self.construct_card_custom_properties(card_details)

        chart_info = ChartInfoClass(
            type=chart_type,
            description=description,
            title=title,
            lastModified=last_modified,
            chartUrl=f"{self.config.connect_uri}/card/{card_id}",
            inputs=datasource_urn,
            customProperties=custom_properties,
        )
        chart_snapshot.aspects.append(chart_info)

        if card_details.get("query_type", "") == "native":
            raw_query = (card_details.get("dataset_query",
                                          {}).get("native",
                                                  {}).get("query", ""))
            chart_query_native = ChartQueryClass(
                rawQuery=raw_query,
                type=ChartQueryTypeClass.SQL,
            )
            chart_snapshot.aspects.append(chart_query_native)

        # Ownership
        ownership = self._get_ownership(card_details.get("creator_id", ""))
        if ownership is not None:
            chart_snapshot.aspects.append(ownership)

        return chart_snapshot

    def _get_chart_type(self, card_id: int,
                        display_type: Optional[str]) -> Optional[str]:
        type_mapping = {
            "table": ChartTypeClass.TABLE,
            "bar": ChartTypeClass.BAR,
            "line": ChartTypeClass.LINE,
            "row": ChartTypeClass.BAR,
            "area": ChartTypeClass.AREA,
            "pie": ChartTypeClass.PIE,
            "funnel": ChartTypeClass.BAR,
            "scatter": ChartTypeClass.SCATTER,
            "scalar": ChartTypeClass.TEXT,
            "smartscalar": ChartTypeClass.TEXT,
            "pivot": ChartTypeClass.TABLE,
            "waterfall": ChartTypeClass.BAR,
            "progress": None,
            "combo": None,
            "gauge": None,
            "map": None,
        }
        if not display_type:
            self.report.report_warning(
                key=f"metabase-card-{card_id}",
                reason="Card display type is missing. Setting chart type to None",
            )
            return None
        try:
            chart_type = type_mapping[display_type]
        except KeyError:
            self.report.report_warning(
                key=f"metabase-card-{card_id}",
                reason=
                f"Chart type {display_type} not supported. Setting to None",
            )
            chart_type = None

        return chart_type

    def construct_card_custom_properties(self, card_details: dict) -> Dict:
        result_metadata = card_details.get("result_metadata") or []
        metrics, dimensions = [], []
        for meta_data in result_metadata:
            display_name = meta_data.get("display_name", "") or ""
            metrics.append(display_name) if "aggregation" in meta_data.get(
                "field_ref", "") else dimensions.append(display_name)

        filters = (card_details.get("dataset_query",
                                    {}).get("query", {})).get("filter", [])

        custom_properties = {
            "Metrics": ", ".join(metrics),
            "Filters": f"{filters}" if len(filters) else "",
            "Dimensions": ", ".join(dimensions),
        }

        return custom_properties
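
A standalone sketch of the metric/dimension split above, using a hypothetical Metabase result_metadata payload; in the API response, field_ref starts with "aggregation" for aggregated columns and "field" otherwise:

result_metadata = [
    {"display_name": "Created At", "field_ref": ["field", 42, None]},
    {"display_name": "Count", "field_ref": ["aggregation", 0]},
]

metrics, dimensions = [], []
for meta_data in result_metadata:
    display_name = meta_data.get("display_name", "") or ""
    if "aggregation" in meta_data.get("field_ref", ""):
        metrics.append(display_name)
    else:
        dimensions.append(display_name)

print(metrics, dimensions)  # ['Count'] ['Created At']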

    def get_datasource_urn(self, card_details):
        platform, database_name, platform_instance = self.get_datasource_from_id(
            card_details.get("database_id", ""))
        query_type = card_details.get("dataset_query", {}).get("type", {})
        source_paths = set()

        if query_type == "query":
            source_table_id = (card_details.get("dataset_query", {}).get(
                "query", {}).get("source-table"))
            if source_table_id is not None:
                schema_name, table_name = self.get_source_table_from_id(
                    source_table_id)
                if table_name:
                    source_paths.add(
                        f"{schema_name + '.' if schema_name else ''}{table_name}"
                    )
        else:
            try:
                raw_query = (card_details.get("dataset_query",
                                              {}).get("native",
                                                      {}).get("query", ""))
                parser = LineageRunner(raw_query)

                for table in parser.source_tables:
                    sources = str(table).split(".")
                    source_schema, source_table = sources[-2], sources[-1]
                    if source_schema == "<default>":
                        source_schema = str(self.config.default_schema)

                    source_paths.add(f"{source_schema}.{source_table}")
            except Exception as e:
                self.report.report_failure(
                    key="metabase-query",
                    reason=f"Unable to retrieve lineage from query. "
                    f"Query: {raw_query} "
                    f"Reason: {str(e)} ",
                )
                return None

        # Create dataset URNs
        dataset_urn = []
        dbname = f"{database_name + '.' if database_name else ''}"
        source_tables = list(map(lambda tbl: f"{dbname}{tbl}", source_paths))
        dataset_urn = [
            builder.make_dataset_urn_with_platform_instance(
                platform=platform,
                name=name,
                platform_instance=platform_instance,
                env=self.config.env,
            ) for name in source_tables
        ]

        return dataset_urn

    @lru_cache(maxsize=None)
    def get_source_table_from_id(self, table_id):
        try:
            dataset_response = self.session.get(
                f"{self.config.connect_uri}/api/table/{table_id}")
            dataset_response.raise_for_status()
            dataset_json = dataset_response.json()
            schema = dataset_json.get("schema", "")
            name = dataset_json.get("name", "")
            return schema, name

        except HTTPError as http_error:
            self.report.report_failure(
                key=f"metabase-table-{table_id}",
                reason=f"Unable to retrieve source table. "
                f"Reason: {str(http_error)}",
            )

        return None, None

    @lru_cache(maxsize=None)
    def get_datasource_from_id(self, datasource_id):
        try:
            dataset_response = self.session.get(
                f"{self.config.connect_uri}/api/database/{datasource_id}")
            dataset_response.raise_for_status()
            dataset_json = dataset_response.json()
        except HTTPError as http_error:
            self.report.report_failure(
                key=f"metabase-datasource-{datasource_id}",
                reason=f"Unable to retrieve Datasource. "
                f"Reason: {str(http_error)}",
            )
            return None, None

        # Map engine names to what datahub expects in
        # https://github.com/datahub-project/datahub/blob/master/metadata-service/war/src/main/resources/boot/data_platforms.json
        engine = dataset_json.get("engine", "")
        platform = engine

        engine_mapping = {
            "sparksql": "spark",
            "mongo": "mongodb",
            "presto-jdbc": "presto",
            "sqlserver": "mssql",
            "bigquery-cloud-sdk": "bigquery",
        }

        if self.config.engine_platform_map is not None:
            engine_mapping.update(self.config.engine_platform_map)

        if engine in engine_mapping:
            platform = engine_mapping[engine]
        else:
            self.report.report_warning(
                key=f"metabase-platform-{datasource_id}",
                reason=
                f"Platform was not found in DataHub. Using {platform} name as is",
            )
        # Set platform_instance if configuration provides a mapping from platform name to instance
        platform_instance = (self.config.platform_instance_map.get(platform)
                             if self.config.platform_instance_map else None)

        field_for_dbname_mapping = {
            "postgres": "dbname",
            "sparksql": "dbname",
            "mongo": "dbname",
            "redshift": "db",
            "snowflake": "db",
            "presto-jdbc": "catalog",
            "presto": "catalog",
            "mysql": "dbname",
            "sqlserver": "db",
        }

        dbname = (dataset_json.get("details", {}).get(
            field_for_dbname_mapping[engine])
                  if engine in field_for_dbname_mapping else None)

        if (self.config.database_alias_map is not None
                and platform in self.config.database_alias_map):
            dbname = self.config.database_alias_map[platform]

        # warn only if the database name is still unknown at this point
        if dbname is None:
            self.report.report_warning(
                key=f"metabase-dbname-{datasource_id}",
                reason=
                f"Cannot determine database name for platform: {platform}",
            )

        return platform, dbname, platform_instance

    @classmethod
    def create(cls, config_dict: dict, ctx: PipelineContext) -> Source:
        config = MetabaseConfig.parse_obj(config_dict)
        return cls(ctx, config)

    def get_workunits(self) -> Iterable[MetadataWorkUnit]:
        yield from self.emit_dashboard_mces()
        yield from self.emit_card_mces()

    def get_report(self) -> SourceReport:
        return self.report
Example #8
class LDAPSource(Source):
    config: LDAPSourceConfig
    report: SourceReport

    def __init__(self, ctx: PipelineContext, config: LDAPSourceConfig):
        super().__init__(ctx)
        self.config = config
        self.report = SourceReport()

        ldap.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_ALLOW)
        ldap.set_option(ldap.OPT_REFERRALS, 0)

        self.ldap_client = ldap.initialize(self.config.ldap_server)
        self.ldap_client.protocol_version = 3

        try:
            self.ldap_client.simple_bind_s(
                self.config.ldap_user, self.config.ldap_password
            )
        except ldap.LDAPError as e:
            raise ConfigurationError("LDAP connection failed") from e

        self.lc = create_controls(self.config.page_size)

    @classmethod
    def create(cls, config_dict, ctx):
        config = LDAPSourceConfig.parse_obj(config_dict)
        return cls(ctx, config)

    def get_workunits(self) -> Iterable[MetadataWorkUnit]:
        cookie = True
        while cookie:
            try:
                msgid = self.ldap_client.search_ext(
                    self.config.base_dn,
                    ldap.SCOPE_SUBTREE,
                    self.config.filter,
                    serverctrls=[self.lc],
                )
                rtype, rdata, rmsgid, serverctrls = self.ldap_client.result3(msgid)
            except ldap.LDAPError as e:
                self.report.report_failure(
                    "ldap-control", "LDAP search failed: {}".format(e)
                )
                break

            for dn, attrs in rdata:
                # TODO: create groups if 'organizationalUnit' in attrs['objectClass']

                if (
                    b"inetOrgPerson" in attrs["objectClass"]
                    or b"posixAccount" in attrs["objectClass"]
                ):
                    yield from self.handle_user(dn, attrs)

            pctrls = get_pctrls(serverctrls)
            if not pctrls:
                self.report.report_failure(
                    "ldap-control", "Server ignores RFC 2696 control."
                )
                break

            cookie = set_cookie(self.lc, pctrls, self.config.page_size)

    def handle_user(self, dn, attrs) -> Iterable[MetadataWorkUnit]:
        """
        Handle a DN and attributes by adding manager info and constructing a
        work unit based on the information.
        """
        manager_ldap = None
        if "manager" in attrs:
            try:
                m_cn = attrs["manager"][0].split(b",")[0]
                manager_msgid = self.ldap_client.search_ext(
                    self.config.base_dn,
                    ldap.SCOPE_SUBTREE,
                    f"({m_cn.decode()})",
                    serverctrls=[self.lc],
                )
                m_dn, m_attrs = self.ldap_client.result3(manager_msgid)[1][0]
                manager_ldap = guess_person_ldap(m_dn, m_attrs)
            except ldap.LDAPError as e:
                self.report.report_warning(
                    dn, "manager LDAP search failed: {}".format(e)
                )

        mce = self.build_corp_user_mce(dn, attrs, manager_ldap)
        if mce:
            wu = MetadataWorkUnit(dn, mce)
            self.report.report_workunit(wu)
            yield wu

    def build_corp_user_mce(
        self, dn, attrs, manager_ldap
    ) -> Optional[MetadataChangeEvent]:
        """
        Create the MetadataChangeEvent via DN and attributes.
        """
        ldap_user = guess_person_ldap(dn, attrs)  # renamed to avoid shadowing the ldap module
        full_name = attrs["cn"][0].decode()
        first_name = attrs["givenName"][0].decode()
        last_name = attrs["sn"][0].decode()
        email = (attrs["mail"][0]).decode() if "mail" in attrs else None
        display_name = (
            (attrs["displayName"][0]).decode() if "displayName" in attrs else full_name
        )
        department = (
            (attrs["departmentNumber"][0]).decode()
            if "departmentNumber" in attrs
            else None
        )
        title = attrs["title"][0].decode() if "title" in attrs else None
        manager_urn = f"urn:li:corpuser:{manager_ldap}" if manager_ldap else None

        mce = MetadataChangeEvent(
            proposedSnapshot=CorpUserSnapshotClass(
                urn=f"urn:li:corpuser:{ldap}",
                aspects=[
                    CorpUserInfoClass(
                        active=True,
                        email=email,
                        fullName=full_name,
                        firstName=first_name,
                        lastName=last_name,
                        departmentName=department,
                        displayName=display_name,
                        title=title,
                        managerUrn=manager_urn,
                    )
                ],
            )
        )

        return mce

    def get_report(self):
        return self.report

    def close(self):
        self.ldap_client.unbind()
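
A minimal sketch of the paging helpers this source calls (create_controls, get_pctrls, set_cookie), assuming they wrap python-ldap's RFC 2696 SimplePagedResultsControl; names and signatures are inferred from the call sites above rather than taken from the real module:

from ldap.controls import SimplePagedResultsControl

def create_controls(pagesize):
    """Build a paged-results control with an empty cookie for the first page."""
    return SimplePagedResultsControl(True, size=pagesize, cookie="")

def get_pctrls(serverctrls):
    """Filter the paged-results controls out of the server's response controls."""
    return [
        c for c in serverctrls
        if c.controlType == SimplePagedResultsControl.controlType
    ]

def set_cookie(lc_object, pctrls, pagesize):
    """Copy the server's cookie into the request control; an empty cookie ends the loop."""
    cookie = pctrls[0].cookie
    lc_object.cookie = cookie
    return cookie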
Example #9
def extract_dbt_entities(
    all_manifest_entities: Dict[str, Dict[str, Any]],
    all_catalog_entities: Dict[str, Dict[str, Any]],
    sources_results: List[Dict[str, Any]],
    load_schemas: bool,
    use_identifiers: bool,
    tag_prefix: str,
    node_type_pattern: AllowDenyPattern,
    report: SourceReport,
    node_name_pattern: AllowDenyPattern,
) -> List[DBTNode]:
    sources_by_id = {x["unique_id"]: x for x in sources_results}

    dbt_entities = []
    for key, manifest_node in all_manifest_entities.items():
        # check if node pattern allowed based on config file
        if not node_type_pattern.allowed(manifest_node["resource_type"]):
            continue

        name = manifest_node["name"]
        if "identifier" in manifest_node and use_identifiers:
            name = manifest_node["identifier"]

        if manifest_node.get("alias") is not None:
            name = manifest_node["alias"]

        if not node_name_pattern.allowed(key):
            continue

        # initialize comment to "" for consistency with descriptions
        # (since dbt null/undefined descriptions as "")
        comment = ""

        if key in all_catalog_entities and all_catalog_entities[key][
                "metadata"].get("comment"):
            comment = all_catalog_entities[key]["metadata"]["comment"]

        materialization = None
        upstream_nodes = []

        if "materialized" in manifest_node.get("config", {}):
            # It's a model
            materialization = manifest_node["config"]["materialized"]
            upstream_nodes = manifest_node["depends_on"]["nodes"]

        # Look up the catalog entry (present for models and sources alike)
        catalog_node = all_catalog_entities.get(key)
        catalog_type = None

        if catalog_node is None:
            report.report_warning(
                key,
                f"Entity {key} ({name}) is in manifest but missing from catalog",
            )
        else:
            catalog_type = all_catalog_entities[key]["metadata"]["type"]

        meta = manifest_node.get("meta", {})

        owner = meta.get("owner")
        if owner is None:
            owner = manifest_node.get("config", {}).get("meta",
                                                        {}).get("owner")

        tags = manifest_node.get("tags", [])
        tags = [tag_prefix + tag for tag in tags]
        meta_props = manifest_node.get("meta", {})
        if not meta:
            meta_props = manifest_node.get("config", {}).get("meta", {})
        dbtNode = DBTNode(
            dbt_name=key,
            database=manifest_node["database"],
            schema=manifest_node["schema"],
            name=name,
            dbt_file_path=manifest_node["original_file_path"],
            node_type=manifest_node["resource_type"],
            max_loaded_at=sources_by_id.get(key, {}).get("max_loaded_at"),
            comment=comment,
            description=manifest_node.get("description", ""),
            raw_sql=manifest_node.get("raw_sql"),
            upstream_nodes=upstream_nodes,
            materialization=materialization,
            catalog_type=catalog_type,
            columns=[],
            meta=meta_props,
            tags=tags,
            owner=owner,
        )

        # overwrite columns from catalog
        if (dbtNode.materialization !=
                "ephemeral"):  # ephemeral models are never materialized, so they have no catalog columns
            logger.debug("Loading schema info")
            catalog_node = all_catalog_entities.get(key)

            if catalog_node is None:
                report.report_warning(
                    key,
                    f"Entity {dbtNode.dbt_name} is in manifest but missing from catalog",
                )
            else:
                dbtNode.columns = get_columns(catalog_node, manifest_node,
                                              tag_prefix)

        else:
            dbtNode.columns = []

        dbt_entities.append(dbtNode)

    return dbt_entities
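
A hedged usage sketch: dbt writes its artifacts under target/, and the manifest keeps models under "nodes" and sources under "sources", so the two are merged into the dicts this function expects. AllowDenyPattern.allow_all() and SourceReport are assumed to be importable from the datahub ingestion package:

import json

with open("target/manifest.json") as f:
    manifest = json.load(f)
with open("target/catalog.json") as f:
    catalog = json.load(f)

nodes = extract_dbt_entities(
    all_manifest_entities={**manifest["nodes"], **manifest["sources"]},
    all_catalog_entities={**catalog["nodes"], **catalog["sources"]},
    sources_results=[],  # populated from target/sources.json when freshness checks ran
    load_schemas=True,
    use_identifiers=False,
    tag_prefix="dbt:",
    node_type_pattern=AllowDenyPattern.allow_all(),
    report=SourceReport(),
    node_name_pattern=AllowDenyPattern.allow_all(),
)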
Example #10
    def from_api(  # noqa: C901
        cls,
        model: str,
        explore_name: str,
        client: Looker31SDK,
        reporter: SourceReport,
    ) -> Optional["LookerExplore"]:  # noqa: C901
        try:
            explore = client.lookml_model_explore(model, explore_name)
            views = set()
            if explore.joins is not None and explore.joins != []:
                if explore.view_name is not None and explore.view_name != explore.name:
                    # explore is renaming the view name, we will need to swap references to explore.name with explore.view_name
                    aliased_explore = True
                    views.add(explore.view_name)
                else:
                    aliased_explore = False

                for e_join in [
                        e for e in explore.joins
                        if e.dependent_fields is not None
                ]:
                    assert e_join.dependent_fields is not None
                    for field_name in e_join.dependent_fields:
                        try:
                            view_name = LookerUtil._extract_view_from_field(
                                field_name)
                            if (view_name == explore.name) and aliased_explore:
                                assert explore.view_name is not None
                                view_name = explore.view_name
                            views.add(view_name)
                        except AssertionError:
                            reporter.report_warning(
                                key=f"chart-field-{field_name}",
                                reason=
                                "The field was not prefixed by a view name. This can happen when the field references another dynamic field.",
                            )
                            continue
            else:
                assert explore.view_name is not None
                views.add(explore.view_name)

            view_fields: List[ViewField] = []
            if explore.fields is not None:
                if explore.fields.dimensions is not None:
                    for dim_field in explore.fields.dimensions:
                        if dim_field.name is None:
                            continue
                        else:
                            view_fields.append(
                                ViewField(
                                    name=dim_field.name,
                                    description=dim_field.description
                                    if dim_field.description else "",
                                    type=dim_field.type
                                    if dim_field.type is not None else "",
                                    field_type=ViewFieldType.DIMENSION_GROUP
                                    if dim_field.dimension_group is not None
                                    else ViewFieldType.DIMENSION,
                                    is_primary_key=dim_field.primary_key
                                    if dim_field.primary_key else False,
                                ))
                if explore.fields.measures is not None:
                    for measure_field in explore.fields.measures:
                        if measure_field.name is None:
                            continue
                        else:
                            view_fields.append(
                                ViewField(
                                    name=measure_field.name,
                                    description=measure_field.description
                                    if measure_field.description else "",
                                    type=measure_field.type
                                    if measure_field.type is not None else "",
                                    field_type=ViewFieldType.MEASURE,
                                    is_primary_key=measure_field.primary_key
                                    if measure_field.primary_key else False,
                                ))

            return cls(
                name=explore_name,
                model_name=model,
                project_name=explore.project_name,
                label=explore.label,
                description=explore.description,
                fields=view_fields,
                upstream_views=list(views),
                source_file=explore.source_file,
            )
        except SDKError:
            logger.warn("Failed to extract explore {} from model {}.".format(
                explore_name, model))
        except AssertionError:
            reporter.report_warning(
                key="chart-",
                reason="Was unable to find dependent views for this chart",
            )
        return None
Example #11
class FeastRepositorySource(Source):
    """
    This plugin extracts:

    - Entities as [`MLPrimaryKey`](https://datahubproject.io/docs/graphql/objects#mlprimarykey)
    - Features as [`MLFeature`](https://datahubproject.io/docs/graphql/objects#mlfeature)
    - Feature views and on-demand feature views as [`MLFeatureTable`](https://datahubproject.io/docs/graphql/objects#mlfeaturetable)
    - Batch and stream source details as [`Dataset`](https://datahubproject.io/docs/graphql/objects#dataset)
    - Column types associated with each entity and feature
    """

    source_config: FeastRepositorySourceConfig
    report: SourceReport
    feature_store: FeatureStore

    def __init__(self, config: FeastRepositorySourceConfig, ctx: PipelineContext):
        super().__init__(ctx)

        self.source_config = config
        self.report = SourceReport()
        self.feature_store = FeatureStore(self.source_config.path)

    def _get_field_type(self, field_type: ValueType, parent_name: str) -> str:
        """
        Maps types encountered in Feast to corresponding schema types.
        """

        ml_feature_data_type = _field_type_mapping.get(field_type)

        if ml_feature_data_type is None:
            self.report.report_warning(
                parent_name, f"unable to map type {field_type} to metadata schema"
            )

            ml_feature_data_type = MLFeatureDataType.UNKNOWN

        return ml_feature_data_type

    def _get_data_source_details(self, source: DataSource) -> Tuple[str, str]:
        """
        Get Feast batch/stream source platform and name.
        """

        platform = "unknown"
        name = "unknown"

        if isinstance(source, FileSource):
            platform = "file"

            name = source.path.replace("://", ".").replace("/", ".")

        if isinstance(source, BigQuerySource):
            platform = "bigquery"
            name = source.table

        if isinstance(source, KafkaSource):
            platform = "kafka"
            name = source.kafka_options.topic

        if isinstance(source, KinesisSource):
            platform = "kinesis"
            name = (
                f"{source.kinesis_options.region}:{source.kinesis_options.stream_name}"
            )

        if isinstance(source, RequestDataSource):
            platform = "request"
            name = source.name

        return platform, name
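
A standalone sketch of the FileSource naming rule above: the scheme separator and path slashes are both rewritten to dots.

path = "s3://my-bucket/driver_stats.parquet"
print(path.replace("://", ".").replace("/", "."))
# s3.my-bucket.driver_stats.parquet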

    def _get_data_sources(self, feature_view: FeatureView) -> List[str]:
        """
        Get data source URN list.
        """

        sources = []

        if feature_view.batch_source is not None:
            batch_source_platform, batch_source_name = self._get_data_source_details(
                feature_view.batch_source
            )
            sources.append(
                builder.make_dataset_urn(
                    batch_source_platform,
                    batch_source_name,
                    self.source_config.environment,
                )
            )

        if feature_view.stream_source is not None:
            stream_source_platform, stream_source_name = self._get_data_source_details(
                feature_view.stream_source
            )
            sources.append(
                builder.make_dataset_urn(
                    stream_source_platform,
                    stream_source_name,
                    self.source_config.environment,
                )
            )

        return sources

    def _get_entity_workunit(
        self, feature_view: FeatureView, entity: Entity
    ) -> MetadataWorkUnit:
        """
        Generate an MLPrimaryKey work unit for a Feast entity.
        """

        feature_view_name = f"{self.feature_store.project}.{feature_view.name}"

        entity_snapshot = MLPrimaryKeySnapshot(
            urn=builder.make_ml_primary_key_urn(feature_view_name, entity.name),
            aspects=[StatusClass(removed=False)],
        )

        entity_snapshot.aspects.append(
            MLPrimaryKeyPropertiesClass(
                description=entity.description,
                dataType=self._get_field_type(entity.value_type, entity.name),
                sources=self._get_data_sources(feature_view),
            )
        )

        mce = MetadataChangeEvent(proposedSnapshot=entity_snapshot)

        return MetadataWorkUnit(id=entity.name, mce=mce)

    def _get_feature_workunit(
        self,
        feature_view: Union[FeatureView, OnDemandFeatureView],
        feature: Feature,
    ) -> MetadataWorkUnit:
        """
        Generate an MLFeature work unit for a Feast feature.
        """
        feature_view_name = f"{self.feature_store.project}.{feature_view.name}"

        feature_snapshot = MLFeatureSnapshot(
            urn=builder.make_ml_feature_urn(feature_view_name, feature.name),
            aspects=[StatusClass(removed=False)],
        )

        feature_sources = []

        if isinstance(feature_view, FeatureView):
            feature_sources = self._get_data_sources(feature_view)
        elif isinstance(feature_view, OnDemandFeatureView):
            if feature_view.input_request_data_sources is not None:
                for request_source in feature_view.input_request_data_sources.values():
                    source_platform, source_name = self._get_data_source_details(
                        request_source
                    )

                    feature_sources.append(
                        builder.make_dataset_urn(
                            source_platform,
                            source_name,
                            self.source_config.environment,
                        )
                    )

            if feature_view.input_feature_view_projections is not None:
                for (
                    feature_view_projection
                ) in feature_view.input_feature_view_projections.values():
                    feature_view_source = self.feature_store.get_feature_view(
                        feature_view_projection.name
                    )

                    feature_sources.extend(self._get_data_sources(feature_view_source))

        feature_snapshot.aspects.append(
            MLFeaturePropertiesClass(
                description=feature.labels.get("description"),
                dataType=self._get_field_type(feature.dtype, feature.name),
                sources=feature_sources,
            )
        )

        mce = MetadataChangeEvent(proposedSnapshot=feature_snapshot)

        return MetadataWorkUnit(id=feature.name, mce=mce)

    def _get_feature_view_workunit(self, feature_view: FeatureView) -> MetadataWorkUnit:
        """
        Generate an MLFeatureTable work unit for a Feast feature view.
        """

        feature_view_name = f"{self.feature_store.project}.{feature_view.name}"

        feature_view_snapshot = MLFeatureTableSnapshot(
            urn=builder.make_ml_feature_table_urn("feast", feature_view_name),
            aspects=[
                BrowsePathsClass(
                    paths=[f"/feast/{self.feature_store.project}/{feature_view_name}"]
                ),
                StatusClass(removed=False),
            ],
        )

        feature_view_snapshot.aspects.append(
            MLFeatureTablePropertiesClass(
                mlFeatures=[
                    builder.make_ml_feature_urn(
                        feature_view_name,
                        feature.name,
                    )
                    for feature in feature_view.features
                ],
                mlPrimaryKeys=[
                    builder.make_ml_primary_key_urn(feature_view_name, entity_name)
                    for entity_name in feature_view.entities
                ],
            )
        )

        mce = MetadataChangeEvent(proposedSnapshot=feature_view_snapshot)

        return MetadataWorkUnit(id=feature_view_name, mce=mce)

    def _get_on_demand_feature_view_workunit(
        self, on_demand_feature_view: OnDemandFeatureView
    ) -> MetadataWorkUnit:
        """
        Generate an MLFeatureTable work unit for a Feast on-demand feature view.
        """

        on_demand_feature_view_name = (
            f"{self.feature_store.project}.{on_demand_feature_view.name}"
        )

        on_demand_feature_view_snapshot = MLFeatureTableSnapshot(
            urn=builder.make_ml_feature_table_urn("feast", on_demand_feature_view_name),
            aspects=[
                BrowsePathsClass(
                    paths=[
                        f"/feast/{self.feature_store.project}/{on_demand_feature_view_name}"
                    ]
                ),
                StatusClass(removed=False),
            ],
        )

        on_demand_feature_view_snapshot.aspects.append(
            MLFeatureTablePropertiesClass(
                mlFeatures=[
                    builder.make_ml_feature_urn(
                        on_demand_feature_view_name,
                        feature.name,
                    )
                    for feature in on_demand_feature_view.features
                ],
                mlPrimaryKeys=[],
            )
        )

        mce = MetadataChangeEvent(proposedSnapshot=on_demand_feature_view_snapshot)

        return MetadataWorkUnit(id=on_demand_feature_view_name, mce=mce)

    @classmethod
    def create(cls, config_dict, ctx):
        config = FeastRepositorySourceConfig.parse_obj(config_dict)
        return cls(config, ctx)

    def get_workunits(self) -> Iterable[MetadataWorkUnit]:
        for feature_view in self.feature_store.list_feature_views():
            for entity_name in feature_view.entities:
                entity = self.feature_store.get_entity(entity_name)

                work_unit = self._get_entity_workunit(feature_view, entity)
                self.report.report_workunit(work_unit)

                yield work_unit

            for feature in feature_view.features:
                work_unit = self._get_feature_workunit(feature_view, feature)
                self.report.report_workunit(work_unit)

                yield work_unit

            work_unit = self._get_feature_view_workunit(feature_view)
            self.report.report_workunit(work_unit)

            yield work_unit

        for on_demand_feature_view in self.feature_store.list_on_demand_feature_views():
            for feature in on_demand_feature_view.features:
                work_unit = self._get_feature_workunit(on_demand_feature_view, feature)
                self.report.report_workunit(work_unit)

                yield work_unit

            work_unit = self._get_on_demand_feature_view_workunit(
                on_demand_feature_view
            )
            self.report.report_workunit(work_unit)

            yield work_unit

    def get_report(self) -> SourceReport:
        return self.report

    def close(self) -> None:
        return
Example #12
class APISource(Source, ABC):
    """

    This plugin is meant to gather dataset-like information about OpenAPI endpoints.

    For example, if calling GET on the endpoint `https://test_endpoint.com/api/users/` returns:
    ```JSON
    [{"user": "******",
      "name": "Albert Einstein",
      "job": "nature declutterer",
      "is_active": true},
      {"user": "******",
      "name": "Phytagoras of Kroton",
      "job": "Phylosopher on steroids",
      "is_active": true}
    ]
    ```

    in DataHub you will see a dataset called `test_endpoint/users` which contains the fields `user`, `name` and `job`.

    """
    def __init__(self, config: OpenApiConfig, ctx: PipelineContext,
                 platform: str):
        super().__init__(ctx)
        self.config = config
        self.platform = platform
        self.report = SourceReport()
        self.url_basepath = ""

    def report_bad_responses(self, status_code: int, key: str) -> None:
        if status_code == 400:
            self.report.report_warning(
                key=key, reason="Unknown error for reaching endpoint")
        elif status_code == 403:
            self.report.report_warning(key=key,
                                       reason="Not authorised to get endpoint")
        elif status_code == 404:
            self.report.report_warning(
                key=key,
                reason=
                "Unable to find an example for endpoint. Please add it to the list of forced examples.",
            )
        elif status_code == 500:
            self.report.report_warning(
                key=key, reason="Server error for reaching endpoint")
        elif status_code == 504:
            self.report.report_warning(key=key,
                                       reason="Timeout for reaching endpoint")
        else:
            raise Exception(
                f"Unable to retrieve endpoint, response code {status_code}, key {key}"
            )

    def init_dataset(self, endpoint_k: str,
                     endpoint_dets: dict) -> Tuple[DatasetSnapshot, str]:
        config = self.config

        dataset_name = endpoint_k[1:].replace("/", ".")

        if len(dataset_name) > 0:
            if dataset_name[-1] == ".":
                dataset_name = dataset_name[:-1]
        else:
            dataset_name = "root"

        dataset_snapshot = DatasetSnapshot(
            urn=
            f"urn:li:dataset:(urn:li:dataPlatform:{self.platform},{config.name}.{dataset_name},PROD)",
            aspects=[],
        )

        # adding description
        dataset_properties = DatasetPropertiesClass(
            description=endpoint_dets["description"], customProperties={})
        dataset_snapshot.aspects.append(dataset_properties)

        # adding tags
        tags_str = [make_tag_urn(t) for t in endpoint_dets["tags"]]
        tags_tac = [TagAssociationClass(t) for t in tags_str]
        gtc = GlobalTagsClass(tags_tac)
        dataset_snapshot.aspects.append(gtc)

        # the link will appear in the "documentation"
        link_url = clean_url(config.url + self.url_basepath + endpoint_k)
        link_description = "Link to call for the dataset."
        creation = AuditStampClass(time=int(time.time() * 1000),  # AuditStamp expects epoch millis
                                   actor="urn:li:corpuser:etl",
                                   impersonator=None)
        link_metadata = InstitutionalMemoryMetadataClass(
            url=link_url, description=link_description, createStamp=creation)
        inst_memory = InstitutionalMemoryClass([link_metadata])
        dataset_snapshot.aspects.append(inst_memory)

        return dataset_snapshot, dataset_name
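
A standalone sketch of the naming rule implemented above: strip the leading slash, replace slashes with dots, drop a trailing dot, and fall back to "root" for the bare "/" endpoint.

for endpoint_k in ["/api/users/", "/health", "/"]:
    dataset_name = endpoint_k[1:].replace("/", ".")
    if len(dataset_name) > 0:
        if dataset_name[-1] == ".":
            dataset_name = dataset_name[:-1]
    else:
        dataset_name = "root"
    print(f"{endpoint_k!r} -> {dataset_name!r}")
# '/api/users/' -> 'api.users'
# '/health' -> 'health'
# '/' -> 'root'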

    def build_wu(self, dataset_snapshot: DatasetSnapshot,
                 dataset_name: str) -> Generator[ApiWorkUnit, None, None]:
        mce = MetadataChangeEvent(proposedSnapshot=dataset_snapshot)
        wu = ApiWorkUnit(id=dataset_name, mce=mce)
        self.report.report_workunit(wu)
        yield wu

    def get_workunits(self) -> Iterable[ApiWorkUnit]:  # noqa: C901
        config = self.config

        sw_dict = self.config.get_swagger()

        self.url_basepath = get_url_basepath(sw_dict)

        # Getting all the URLs accepting the "GET" method
        with warnings.catch_warnings(record=True) as warn_c:
            url_endpoints = get_endpoints(sw_dict)

            for w in warn_c:
                w_msg = w.message
                w_spl = w_msg.args[0].split(" --- ")  # type: ignore
                self.report.report_warning(key=w_spl[1], reason=w_spl[0])

        # here we put a sample from the "listing endpoint", to be used later when guessing composed endpoints.
        root_dataset_samples = {}

        # looping on all the urls
        for endpoint_k, endpoint_dets in url_endpoints.items():
            if endpoint_k in config.ignore_endpoints:
                continue

            dataset_snapshot, dataset_name = self.init_dataset(
                endpoint_k, endpoint_dets)

            # adding dataset fields
            if "data" in endpoint_dets.keys():
                # we are lucky! data is defined in the swagger for this endpoint
                schema_metadata = set_metadata(dataset_name,
                                               endpoint_dets["data"])
                dataset_snapshot.aspects.append(schema_metadata)
                yield from self.build_wu(dataset_snapshot, dataset_name)
            elif ("{" not in endpoint_k
                  ):  # if the API does not explicitely require parameters
                tot_url = clean_url(config.url + self.url_basepath +
                                    endpoint_k)

                if config.token:
                    response = request_call(tot_url, token=config.token)
                else:
                    response = request_call(tot_url,
                                            username=config.username,
                                            password=config.password)
                if response.status_code == 200:
                    fields2add, root_dataset_samples[
                        dataset_name] = extract_fields(response, dataset_name)
                    if not fields2add:
                        self.report.report_warning(key=endpoint_k,
                                                   reason="No Fields")
                    schema_metadata = set_metadata(dataset_name, fields2add)
                    dataset_snapshot.aspects.append(schema_metadata)

                    yield from self.build_wu(dataset_snapshot, dataset_name)
                else:
                    self.report_bad_responses(response.status_code,
                                              key=endpoint_k)
            else:
                if endpoint_k not in config.forced_examples.keys():
                    # start guessing...
                    url_guess = try_guessing(endpoint_k, root_dataset_samples)
                    tot_url = clean_url(config.url + self.url_basepath +
                                        url_guess)
                    if config.token:
                        response = request_call(tot_url, token=config.token)
                    else:
                        response = request_call(tot_url,
                                                username=config.username,
                                                password=config.password)
                    if response.status_code == 200:
                        fields2add, _ = extract_fields(response, dataset_name)
                        if not fields2add:
                            self.report.report_warning(key=endpoint_k,
                                                       reason="No Fields")
                        schema_metadata = set_metadata(dataset_name,
                                                       fields2add)
                        dataset_snapshot.aspects.append(schema_metadata)

                        yield from self.build_wu(dataset_snapshot,
                                                 dataset_name)
                    else:
                        self.report_bad_responses(response.status_code,
                                                  key=endpoint_k)
                else:
                    composed_url = compose_url_attr(
                        raw_url=endpoint_k,
                        attr_list=config.forced_examples[endpoint_k])
                    tot_url = clean_url(config.url + self.url_basepath +
                                        composed_url)
                    if config.token:
                        response = request_call(tot_url, token=config.token)
                    else:
                        response = request_call(tot_url,
                                                username=config.username,
                                                password=config.password)
                    if response.status_code == 200:
                        fields2add, _ = extract_fields(response, dataset_name)
                        if not fields2add:
                            self.report.report_warning(key=endpoint_k,
                                                       reason="No Fields")
                        schema_metadata = set_metadata(dataset_name,
                                                       fields2add)
                        dataset_snapshot.aspects.append(schema_metadata)

                        yield from self.build_wu(dataset_snapshot,
                                                 dataset_name)
                    else:
                        self.report_bad_responses(response.status_code,
                                                  key=endpoint_k)

    def get_report(self):
        return self.report

    def close(self):
        pass
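
A minimal sketch of the endpoint-guessing idea used above: substitute sample
values from a listing endpoint into a parameterized path. The helper below is
illustrative only; its name and behaviour are assumptions, not the source's
compose_url_attr/try_guessing implementation.

def guess_composed_url(raw_url: str, sample: dict) -> str:
    """Fill each {placeholder} in raw_url with the matching key from a sample record."""
    guessed = raw_url
    for key, value in sample.items():
        guessed = guessed.replace("{" + key + "}", str(value))
    return guessed

# guess_composed_url("/users/{id}/posts", {"id": 42}) -> "/users/42/posts"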
Exemple #13
0
class ModeSource(Source):
    config: ModeConfig
    report: SourceReport
    platform = "mode"

    def __hash__(self):
        return id(self)

    def __init__(self, ctx: PipelineContext, config: ModeConfig):
        super().__init__(ctx)
        self.config = config
        self.report = SourceReport()

        self.session = requests.session()
        self.session.auth = HTTPBasicAuth(self.config.token,
                                          self.config.password)
        self.session.headers.update({
            "Content-Type": "application/json",
            "Accept": "application/hal+json",
        })

        # Test the connection
        try:
            self._get_request_json(f"{self.config.connect_uri}/api/account")
        except HTTPError as http_error:
            self.report.report_failure(
                key="mode-session",
                reason=f"Unable to retrieve user "
                f"{self.config.token} information, "
                f"{str(http_error)}",
            )

        self.workspace_uri = f"{self.config.connect_uri}/api/{self.config.workspace}"
        self.space_tokens = self._get_space_name_and_tokens()

    def construct_dashboard(self, space_name: str,
                            report_info: dict) -> DashboardSnapshot:
        report_token = report_info.get("token", "")
        dashboard_urn = builder.make_dashboard_urn(self.platform,
                                                   report_info.get("id", ""))
        dashboard_snapshot = DashboardSnapshot(
            urn=dashboard_urn,
            aspects=[],
        )

        # set title/description up front so they are always defined, even when
        # the creator lookup below fails
        title = report_info.get("name", "") or ""
        description = report_info.get("description", "") or ""
        last_modified = ChangeAuditStamps()
        creator = self._get_creator(
            report_info.get("_links", {}).get("creator", {}).get("href", ""))
        if creator is not None:
            modified_actor = builder.make_user_urn(creator)
            modified_ts = int(
                dp.parse(
                    f"{report_info.get('last_saved_at', 'now')}").timestamp() *
                1000)
            created_ts = int(
                dp.parse(
                    f"{report_info.get('created_at', 'now')}").timestamp() *
                1000)
            last_modified = ChangeAuditStamps(
                created=AuditStamp(time=created_ts, actor=modified_actor),
                lastModified=AuditStamp(time=modified_ts,
                                        actor=modified_actor),
            )

        dashboard_info_class = DashboardInfoClass(
            description=description,
            title=title,
            charts=self._get_chart_urns(report_token),
            lastModified=last_modified,
            dashboardUrl=f"{self.config.connect_uri}/"
            f"{self.config.workspace}/"
            f"reports/{report_token}",
            customProperties={},
        )
        dashboard_snapshot.aspects.append(dashboard_info_class)

        # browse path
        browse_path = BrowsePathsClass(paths=[
            f"/mode/{self.config.workspace}/"
            f"{space_name}/"
            f"{report_info.get('name')}"
        ])
        dashboard_snapshot.aspects.append(browse_path)

        # Ownership
        ownership = self._get_ownership(
            self._get_creator(
                report_info.get("_links", {}).get("creator",
                                                  {}).get("href", "")))
        if ownership is not None:
            dashboard_snapshot.aspects.append(ownership)

        return dashboard_snapshot

    @lru_cache(maxsize=None)
    def _get_ownership(self, user: str) -> Optional[OwnershipClass]:
        if user is not None:
            owner_urn = builder.make_user_urn(user)
            ownership: OwnershipClass = OwnershipClass(owners=[
                OwnerClass(
                    owner=owner_urn,
                    type=OwnershipTypeClass.DATAOWNER,
                )
            ])
            return ownership

        return None

    @lru_cache(maxsize=None)
    def _get_creator(self, href: str) -> Optional[str]:
        user = None
        try:
            user_json = self._get_request_json(
                f"{self.config.connect_uri}{href}")
            user = (user_json.get("username")
                    if self.config.owner_username_instead_of_email else
                    user_json.get("email"))
        except HTTPError as http_error:
            self.report.report_failure(
                key="mode-user",
                reason=f"Unable to retrieve user for {href}, "
                f"Reason: {str(http_error)}",
            )
        return user

    def _get_chart_urns(self, report_token: str) -> list:
        chart_urns = []
        queries = self._get_queries(report_token)
        for query in queries:
            charts = self._get_charts(report_token, query.get("token", ""))
            # build chart urns
            for chart in charts:
                chart_urn = builder.make_chart_urn(self.platform,
                                                   chart.get("token", ""))
                chart_urns.append(chart_urn)

        return chart_urns

    def _get_space_name_and_tokens(self) -> dict:
        space_info = {}
        try:
            payload = self._get_request_json(f"{self.workspace_uri}/spaces")
            spaces = payload.get("_embedded", {}).get("spaces", {})

            for s in spaces:
                space_info[s.get("token", "")] = s.get("name", "")
        except HTTPError as http_error:
            self.report.report_failure(
                key="mode-spaces",
                reason=
                f"Unable to retrieve spaces/collections for {self.workspace_uri}, "
                f"Reason: {str(http_error)}",
            )

        return space_info

    def _get_chart_type(self, token: str, display_type: str) -> Optional[str]:
        type_mapping = {
            "table": ChartTypeClass.TABLE,
            "bar": ChartTypeClass.BAR,
            "line": ChartTypeClass.LINE,
            "stackedBar100": ChartTypeClass.BAR,
            "stackedBar": ChartTypeClass.BAR,
            "hStackedBar": ChartTypeClass.BAR,
            "hStackedBar100": ChartTypeClass.BAR,
            "hBar": ChartTypeClass.BAR,
            "area": ChartTypeClass.AREA,
            "totalArea": ChartTypeClass.AREA,
            "pie": ChartTypeClass.PIE,
            "donut": ChartTypeClass.PIE,
            "scatter": ChartTypeClass.SCATTER,
            "bigValue": ChartTypeClass.TEXT,
            "pivotTable": ChartTypeClass.TABLE,
            "linePlusBar": None,
        }
        if not display_type:
            self.report.report_warning(
                key=f"mode-chart-{token}",
                reason="Chart type is missing for this chart. "
                "Setting to None",
            )
            return None
        try:
            chart_type = type_mapping[display_type]
        except KeyError:
            self.report.report_warning(
                key=f"mode-chart-{token}",
                reason=f"Chart type {display_type} not supported. "
                f"Setting to None",
            )
            chart_type = None

        return chart_type

    def construct_chart_custom_properties(self, chart_detail: dict,
                                          chart_type: str) -> Dict:

        custom_properties = {}
        metadata = chart_detail.get("encoding", {})
        if chart_type == "table":
            columns = list(chart_detail.get("fieldFormats", {}).keys())
            str_columns = ",".join([c[1:-1] for c in columns])
            filters = metadata.get("filter", [])
            filters = filters[0].get("formula", "") if len(filters) else ""

            custom_properties = {
                "Columns": str_columns,
                "Filters": filters[1:-1] if len(filters) else "",
            }

        elif chart_type == "pivotTable":
            pivot_table = chart_detail.get("pivotTable", {})
            columns = pivot_table.get("columns", [])
            rows = pivot_table.get("rows", [])
            values = pivot_table.get("values", [])
            filters = pivot_table.get("filters", [])

            custom_properties = {
                "Columns": ", ".join(columns) if len(columns) else "",
                "Rows": ", ".join(rows) if len(rows) else "",
                "Metrics": ", ".join(values) if len(values) else "",
                "Filters": ", ".join(filters) if len(filters) else "",
            }
            # list each filter in its own row
            for filter_name in filters:
                custom_properties[f"Filter: {filter_name}"] = ", ".join(
                    pivot_table.get("filterValues", {}).get(filter_name, ""))
        # Chart
        else:
            x = metadata.get("x", [])
            x2 = metadata.get("x2", [])
            y = metadata.get("y", [])
            y2 = metadata.get("y2", [])
            value = metadata.get("value", [])
            filters = metadata.get("filter", [])

            custom_properties = {
                "X": x[0].get("formula", "") if len(x) else "",
                "Y": y[0].get("formula", "") if len(y) else "",
                "X2": x2[0].get("formula", "") if len(x2) else "",
                "Y2": y2[0].get("formula", "") if len(y2) else "",
                "Metrics": value[0].get("formula", "") if len(value) else "",
                "Filters":
                filters[0].get("formula", "") if len(filters) else "",
            }

        return custom_properties

    def _get_datahub_friendly_platform(self, adapter, platform):
        # Map adapter names to what datahub expects, per
        # https://github.com/linkedin/datahub/blob/master/metadata-service/war/src/main/resources/boot/data_platforms.json

        platform_mapping = {
            "jdbc:athena": "athena",
            "jdbc:bigquery": "bigquery",
            "jdbc:druid": "druid",
            "jdbc:hive": "hive",
            "jdbc:mysql": "mysql",
            "jdbc:oracle": "oracle",
            "jdbc:postgresql": "postgres",
            "jdbc:presto": "presto",
            "jdbc:redshift": "redshift",
            "jdbc:snowflake": "snowflake",
            "jdbc:spark": "spark",
            "jdbc:sqlserver": "mssql",
            "jdbc:teradata": "teradata",
        }
        if adapter in platform_mapping:
            return platform_mapping[adapter]
        else:
            self.report.report_warning(
                key=f"mode-platform-{adapter}",
                reason=f"Platform {adapter} was not found in DataHub. "
                f"Using {platform} as-is",
            )

        return platform

    @lru_cache(maxsize=None)
    def _get_platform_and_dbname(
            self,
            data_source_id: int) -> Union[Tuple[str, str], Tuple[None, None]]:

        data_sources = []
        try:
            ds_json = self._get_request_json(
                f"{self.workspace_uri}/data_sources")
            data_sources = ds_json.get("_embedded", {}).get("data_sources", [])
        except HTTPError as http_error:
            self.report.report_failure(
                key=f"mode-datasource-{data_source_id}",
                reason=f"No data sources found for datasource id: "
                f"{data_source_id}, "
                f"Reason: {str(http_error)}",
            )

        if not data_sources:
            self.report.report_failure(
                key=f"mode-datasource-{data_source_id}",
                reason=f"No data sources found for datasource id: "
                f"{data_source_id}",
            )
            return None, None

        for data_source in data_sources:
            if data_source.get("id", -1) == data_source_id:
                platform = self._get_datahub_friendly_platform(
                    data_source.get("adapter", ""),
                    data_source.get("name", ""))
                database = data_source.get("database", "")
                return platform, database
        else:
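            # for/else: this branch runs only when the loop above completed
            # without returning, i.e. no data source matched data_source_id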
            self.report.report_failure(
                key=f"mode-datasource-{data_source_id}",
                reason=f"Cannot create datasource urn for datasource id: "
                f"{data_source_id}",
            )
        return None, None

    def _replace_definitions(self, raw_query: str) -> str:
        query = raw_query
        definitions = re.findall("({{[^}{]+}})", raw_query)
        for definition_variable in definitions:
            definition_name, definition_alias = self._parse_definition_name(
                definition_variable)
            definition_query = self._get_definition(definition_name)
            # if the definition cannot be retrieved, replace the {{ }} wrapper anyway so it is not picked up again by the recursive call
            if definition_query is not None:
                query = query.replace(
                    definition_variable,
                    f"({definition_query}) as {definition_alias}")
            else:
                query = query.replace(
                    definition_variable,
                    f"{definition_name} as {definition_alias}")
            query = self._replace_definitions(query)

        return query

    def _parse_definition_name(self,
                               definition_variable: str) -> Tuple[str, str]:
        name, alias = "", ""
        # e.g. '{{ @join_on_definition as alias }}'
        name_match = re.findall(r"@\w+", definition_variable)
        if len(name_match):
            name = name_match[0][1:]
        alias_match = re.findall(
            r"as\s+\S+", definition_variable)  # e.g. ['as    alias_name']
        if len(alias_match):
            alias_match = alias_match[0].split(" ")
            alias = alias_match[-1]

        return name, alias

    @lru_cache(maxsize=None)
    def _get_definition(self, definition_name):
        try:
            definition_json = self._get_request_json(
                f"{self.workspace_uri}/definitions")
            definitions = definition_json.get("_embedded",
                                              {}).get("definitions", [])
            for definition in definitions:
                if definition.get("name", "") == definition_name:
                    return definition.get("source", "")

        except HTTPError as http_error:
            self.report.report_failure(
                key=f"mode-definition-{definition_name}",
                reason=f"Unable to retrieve definition for {definition_name}, "
                f"Reason: {str(http_error)}",
            )
        return None

    @lru_cache(maxsize=None)
    def _get_source_from_query(self, raw_query: str) -> set:
        query = self._replace_definitions(raw_query)
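        # sqllineage's LineageRunner parses the query; source_tables yields
        # dotted table names, with "<default>" standing in for a missing schema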
        parser = LineageRunner(query)
        source_paths = set()
        try:
            for table in parser.source_tables:
                sources = str(table).split(".")
                source_schema, source_table = sources[-2], sources[-1]
                if source_schema == "<default>":
                    source_schema = str(self.config.default_schema)

                source_paths.add(f"{source_schema}.{source_table}")
        except Exception as e:
            self.report.report_failure(
                key="mode-query",
                reason=f"Unable to retrieve lineage from query. "
                f"Query: {raw_query} "
                f"Reason: {str(e)} ",
            )

        return source_paths

    def _get_datasource_urn(self, platform, database, source_tables):
        dataset_urn = None
        # both platform and database must be known to build valid dataset urns
        if platform is not None and database is not None:
            dataset_urn = [
                builder.make_dataset_urn(platform, f"{database}.{s_table}",
                                         self.config.env)
                for s_table in source_tables
            ]

        return dataset_urn

    def construct_chart_from_api_data(self, chart_data: dict, query: dict,
                                      path: str) -> ChartSnapshot:
        chart_urn = builder.make_chart_urn(self.platform,
                                           chart_data.get("token", ""))
        chart_snapshot = ChartSnapshot(
            urn=chart_urn,
            aspects=[],
        )

        last_modified = ChangeAuditStamps()
        creator = self._get_creator(
            chart_data.get("_links", {}).get("creator", {}).get("href", ""))
        if creator is not None:
            modified_actor = builder.make_user_urn(creator)
            created_ts = int(
                dp.parse(chart_data.get("created_at", "now")).timestamp() *
                1000)
            modified_ts = int(
                dp.parse(chart_data.get("updated_at", "now")).timestamp() *
                1000)
            last_modified = ChangeAuditStamps(
                created=AuditStamp(time=created_ts, actor=modified_actor),
                lastModified=AuditStamp(time=modified_ts,
                                        actor=modified_actor),
            )

        chart_detail = (chart_data.get("view", {})
                        if len(chart_data.get("view", {})) != 0 else
                        chart_data.get("view_vegas", {}))

        mode_chart_type = chart_detail.get(
            "chartType", "") or chart_detail.get("selectedChart", "")
        chart_type = self._get_chart_type(chart_data.get("token", ""),
                                          mode_chart_type)
        description = (chart_detail.get("description")
                       or chart_detail.get("chartDescription") or "")
        title = chart_detail.get("title") or chart_detail.get(
            "chartTitle") or ""

        # create datasource urn
        platform, db_name = self._get_platform_and_dbname(
            query.get("data_source_id"))
        source_tables = self._get_source_from_query(query.get("raw_query"))
        datasource_urn = self._get_datasource_urn(platform, db_name,
                                                  source_tables)
        custom_properties = self.construct_chart_custom_properties(
            chart_detail, mode_chart_type)

        # Chart Info
        chart_info = ChartInfoClass(
            type=chart_type,
            description=description,
            title=title,
            lastModified=last_modified,
            chartUrl=f"{self.config.connect_uri}"
            f"{chart_data.get('_links', {}).get('report_viz_web', {}).get('href', '')}",
            inputs=datasource_urn,
            customProperties=custom_properties,
        )
        chart_snapshot.aspects.append(chart_info)

        # Browse Path
        browse_path = BrowsePathsClass(paths=[path])
        chart_snapshot.aspects.append(browse_path)

        # Query
        chart_query = ChartQueryClass(
            rawQuery=query.get("raw_query", ""),
            type=ChartQueryTypeClass.SQL,
        )
        chart_snapshot.aspects.append(chart_query)

        # Ownership
        ownership = self._get_ownership(
            self._get_creator(
                chart_data.get("_links", {}).get("creator",
                                                 {}).get("href", "")))
        if ownership is not None:
            chart_snapshot.aspects.append(ownership)

        return chart_snapshot

    @lru_cache(maxsize=None)
    def _get_reports(self, space_token: str) -> list:
        reports = []
        try:
            reports_json = self._get_request_json(
                f"{self.workspace_uri}/spaces/{space_token}/reports")
            reports = reports_json.get("_embedded", {}).get("reports", {})
        except HTTPError as http_error:
            self.report.report_failure(
                key=f"mode-report-{space_token}",
                reason=
                f"Unable to retrieve reports for space token: {space_token}, "
                f"Reason: {str(http_error)}",
            )
        return reports

    @lru_cache(maxsize=None)
    def _get_queries(self, report_token: str) -> list:
        queries = []
        try:
            queries_json = self._get_request_json(
                f"{self.workspace_uri}/reports/{report_token}/queries")
            queries = queries_json.get("_embedded", {}).get("queries", {})
        except HTTPError as http_error:
            self.report.report_failure(
                key=f"mode-query-{report_token}",
                reason=
                f"Unable to retrieve queries for report token: {report_token}, "
                f"Reason: {str(http_error)}",
            )
        return queries

    @lru_cache(maxsize=None)
    def _get_charts(self, report_token: str, query_token: str) -> list:
        charts = []
        try:
            charts_json = self._get_request_json(
                f"{self.workspace_uri}/reports/{report_token}"
                f"/queries/{query_token}/charts")
            charts = charts_json.get("_embedded", {}).get("charts", {})
        except HTTPError as http_error:
            self.report.report_failure(
                key=f"mode-chart-{report_token}-{query_token}",
                reason=f"Unable to retrieve charts: "
                f"Report token: {report_token} "
                f"Query token: {query_token}, "
                f"Reason: {str(http_error)}",
            )
        return charts

    def _get_request_json(self, url: str) -> Dict:
        r = tenacity.Retrying(
            wait=wait_exponential(
                multiplier=self.config.api_options.retry_backoff_multiplier,
                max=self.config.api_options.max_retry_interval,
            ),
            retry=retry_if_exception_type(HTTPError429),
            stop=stop_after_attempt(self.config.api_options.max_attempts),
        )

        @r.wraps
        def get_request():
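            # r.wraps retries this function with exponential backoff, but only
            # when it raises HTTPError429 (HTTP 429 rate limiting)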
            try:
                response = self.session.get(url)
                response.raise_for_status()
                return response.json()
            except HTTPError as http_error:
                error_response = http_error.response
                if error_response.status_code == 429:
                    # respect the Retry-After header if the server sent one
                    sleep_time = error_response.headers.get("retry-after")
                    if sleep_time is not None:
                        # header values are strings; sleep() needs a number
                        time.sleep(float(sleep_time))
                    raise HTTPError429

                raise http_error

        return get_request()

    def emit_dashboard_mces(self) -> Iterable[MetadataWorkUnit]:
        for space_token, space_name in self.space_tokens.items():
            reports = self._get_reports(space_token)
            for report in reports:
                dashboard_snapshot_from_report = self.construct_dashboard(
                    space_name, report)

                mce = MetadataChangeEvent(
                    proposedSnapshot=dashboard_snapshot_from_report)
                wu = MetadataWorkUnit(id=dashboard_snapshot_from_report.urn,
                                      mce=mce)
                self.report.report_workunit(wu)

                yield wu

    def emit_chart_mces(self) -> Iterable[MetadataWorkUnit]:
        # Space/collection -> report -> query -> Chart
        for space_token, space_name in self.space_tokens.items():
            reports = self._get_reports(space_token)
            for report in reports:
                report_token = report.get("token", "")
                queries = self._get_queries(report_token)
                for query in queries:
                    charts = self._get_charts(report_token,
                                              query.get("token", ""))
                    # build charts
                    for chart in charts:
                        view = chart.get("view") or chart.get("view_vegas")
                        chart_name = view.get("title") or view.get(
                            "chartTitle") or ""
                        path = (f"/mode/{self.config.workspace}/{space_name}"
                                f"/{report.get('name')}/{query.get('name')}/"
                                f"{chart_name}")
                        chart_snapshot = self.construct_chart_from_api_data(
                            chart, query, path)
                        mce = MetadataChangeEvent(
                            proposedSnapshot=chart_snapshot)
                        wu = MetadataWorkUnit(id=chart_snapshot.urn, mce=mce)
                        self.report.report_workunit(wu)

                        yield wu

    @classmethod
    def create(cls, config_dict: dict, ctx: PipelineContext) -> Source:
        config = ModeConfig.parse_obj(config_dict)
        return cls(ctx, config)

    def get_workunits(self) -> Iterable[MetadataWorkUnit]:
        yield from self.emit_dashboard_mces()
        yield from self.emit_chart_mces()

    def get_report(self) -> SourceReport:
        return self.report
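
The definition replacement above hinges on two small regular expressions. A
minimal standalone sketch of the same parsing (it mirrors, rather than reuses,
_parse_definition_name):

import re

def parse_definition(definition_variable: str) -> tuple:
    """Split a Mode snippet like '{{ @name as alias }}' into ('name', 'alias')."""
    name, alias = "", ""
    name_match = re.findall(r"@\w+", definition_variable)
    if name_match:
        name = name_match[0][1:]  # drop the leading "@"
    alias_match = re.findall(r"as\s+\S+", definition_variable)
    if alias_match:
        alias = alias_match[0].split(" ")[-1]
    return name, alias

# parse_definition("{{ @join_on_definition as alias }}") -> ("join_on_definition", "alias")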
Exemple #14
0
class TableauSource(Source):
    config: TableauConfig
    report: SourceReport
    platform = "tableau"
    server: Optional[Server]
    upstream_tables: Dict[str, Tuple[Any, str]]

    def __hash__(self):
        return id(self)

    def __init__(self, ctx: PipelineContext, config: TableauConfig):
        super().__init__(ctx)

        self.config = config
        self.report = SourceReport()
        self.server = None
        # These lists track datasources actively used by workbooks, so that we
        # only retrieve those when emitting published data sources...
        self.datasource_ids_being_used: List[str] = []
        # ...and when emitting custom SQL data sources.
        self.custom_sql_ids_being_used: List[str] = []
        # Initialised per instance so the mapping is not shared as a
        # class-level mutable default.
        self.upstream_tables = {}

        self._authenticate()

    def close(self) -> None:
        if self.server is not None:
            self.server.auth.sign_out()

    def _authenticate(self):
        # https://tableau.github.io/server-client-python/docs/api-ref#authentication
        authentication = None
        if self.config.username and self.config.password:
            authentication = TableauAuth(
                username=self.config.username,
                password=self.config.password,
                site_id=self.config.site,
            )
        elif self.config.token_name and self.config.token_value:
            authentication = PersonalAccessTokenAuth(self.config.token_name,
                                                     self.config.token_value,
                                                     self.config.site)
        else:
            raise ConfigurationError(
                "Tableau Source: Either username/password or token_name/token_value must be set"
            )

        try:
            self.server = Server(self.config.connect_uri,
                                 use_server_version=True)
            self.server.auth.sign_in(authentication)
        except ServerResponseError as e:
            logger.error(e)
            self.report.report_failure(
                key="tableau-login",
                reason="Unable to log in with the credentials provided. "
                f"Reason: {str(e)}",
            )
        except Exception as e:
            logger.error(e)
            self.report.report_failure(key="tableau-login",
                                       reason=f"Unable to Login"
                                       f"Reason: {str(e)}")

    def get_connection_object(
        self,
        query: str,
        connection_type: str,
        query_filter: str,
        count: int = 0,
        current_count: int = 0,
    ) -> Tuple[dict, int, bool]:
        query_data = query_metadata(self.server, query, connection_type, count,
                                    current_count, query_filter)

        if "errors" in query_data:
            self.report.report_warning(
                key="tableau-metadata",
                reason=
                f"Connection: {connection_type} Error: {query_data['errors']}",
            )

        connection_object = (query_data.get("data").get(connection_type, {})
                             if query_data.get("data") else {})

        total_count = connection_object.get("totalCount", 0)
        has_next_page = connection_object.get("pageInfo",
                                              {}).get("hasNextPage", False)
        return connection_object, total_count, has_next_page

    def emit_workbooks(self,
                       workbooks_page_size: int) -> Iterable[MetadataWorkUnit]:

        projects = (f"projectNameWithin: {json.dumps(self.config.projects)}"
                    if self.config.projects else "")

        workbook_connection, total_count, has_next_page = self.get_connection_object(
            workbook_graphql_query, "workbooksConnection", projects)

        current_count = 0
        while has_next_page:
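            # fetch either a full page or just the remainder on the final page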
            count = (workbooks_page_size if current_count +
                     workbooks_page_size < total_count else total_count -
                     current_count)
            (
                workbook_connection,
                total_count,
                has_next_page,
            ) = self.get_connection_object(
                workbook_graphql_query,
                "workbooksConnection",
                projects,
                count,
                current_count,
            )

            current_count += count

            for workbook in workbook_connection.get("nodes", []):
                yield from self.emit_workbook_as_container(workbook)
                yield from self.emit_sheets_as_charts(workbook)
                yield from self.emit_dashboards(workbook)
                yield from self.emit_embedded_datasource(workbook)
                yield from self.emit_upstream_tables()

    def _track_custom_sql_ids(self, field: dict) -> None:
        # Tableau shows custom sql datasource as a table in ColumnField.
        if field.get("__typename", "") == "ColumnField":
            for column in field.get("columns", []):
                table_id = column.get("table", {}).get("id")

                if (table_id is not None
                        and table_id not in self.custom_sql_ids_being_used):
                    self.custom_sql_ids_being_used.append(table_id)

    def _create_upstream_table_lineage(
            self,
            datasource: dict,
            project: str,
            is_custom_sql: bool = False) -> List[UpstreamClass]:
        upstream_tables = []

        for table in datasource.get("upstreamTables", []):
            # Skip upstream tables that lack column info when retrieving an
            # embedded datasource, and tables whose name is None; schema
            # details for these are handled in emit_custom_sql_datasources()
            if not is_custom_sql and not table.get("columns"):
                continue
            elif table.get("name") is None:
                continue

            upstream_db = table.get("database", {}).get("name", "")
            schema = self._get_schema(table.get("schema", ""), upstream_db)
            table_urn = make_table_urn(
                self.config.env,
                upstream_db,
                table.get("connectionType", ""),
                schema,
                table.get("name", ""),
            )

            upstream_table = UpstreamClass(
                dataset=table_urn,
                type=DatasetLineageTypeClass.TRANSFORMED,
            )
            upstream_tables.append(upstream_table)
            table_path = f"{project.replace('/', REPLACE_SLASH_CHAR)}/{datasource.get('name', '')}/{table.get('name', '')}"
            self.upstream_tables[table_urn] = (
                table.get("columns", []),
                table_path,
            )

        for datasource in datasource.get("upstreamDatasources", []):
            datasource_urn = builder.make_dataset_urn(self.platform,
                                                      datasource["id"],
                                                      self.config.env)
            upstream_table = UpstreamClass(
                dataset=datasource_urn,
                type=DatasetLineageTypeClass.TRANSFORMED,
            )
            upstream_tables.append(upstream_table)

        return upstream_tables

    def emit_custom_sql_datasources(self) -> Iterable[MetadataWorkUnit]:
        count_on_query = len(self.custom_sql_ids_being_used)
        custom_sql_filter = "idWithin: {}".format(
            json.dumps(self.custom_sql_ids_being_used))
        custom_sql_connection, total_count, has_next_page = self.get_connection_object(
            custom_sql_graphql_query, "customSQLTablesConnection",
            custom_sql_filter)

        current_count = 0
        while has_next_page:
            count = (count_on_query if current_count +
                     count_on_query < total_count else total_count -
                     current_count)
            (
                custom_sql_connection,
                total_count,
                has_next_page,
            ) = self.get_connection_object(
                custom_sql_graphql_query,
                "customSQLTablesConnection",
                custom_sql_filter,
                count,
                current_count,
            )
            current_count += count

            unique_custom_sql = get_unique_custom_sql(
                custom_sql_connection.get("nodes", []))
            for csql in unique_custom_sql:
                csql_id: str = csql.get("id", "")
                csql_urn = builder.make_dataset_urn(self.platform, csql_id,
                                                    self.config.env)
                dataset_snapshot = DatasetSnapshot(
                    urn=csql_urn,
                    aspects=[],
                )

                # lineage from datasource -> custom sql source #
                yield from self._create_lineage_from_csql_datasource(
                    csql_urn, csql.get("datasources", []))

                # lineage from custom sql -> datasets/tables #
                columns = csql.get("columns", [])
                yield from self._create_lineage_to_upstream_tables(
                    csql_urn, columns)

                #  Schema Metadata
                schema_metadata = self.get_schema_metadata_for_custom_sql(
                    columns)
                if schema_metadata is not None:
                    dataset_snapshot.aspects.append(schema_metadata)

                # Browse path
                browse_paths = BrowsePathsClass(paths=[
                    f"/{self.config.env.lower()}/{self.platform}/Custom SQL/{csql.get('name', '')}/{csql_id}"
                ])
                dataset_snapshot.aspects.append(browse_paths)

                dataset_properties = DatasetPropertiesClass(
                    name=csql.get("name"), description=csql.get("description"))

                dataset_snapshot.aspects.append(dataset_properties)

                view_properties = ViewPropertiesClass(
                    materialized=False,
                    viewLanguage="SQL",
                    viewLogic=clean_query(csql.get("query", "")),
                )
                dataset_snapshot.aspects.append(view_properties)

                yield self.get_metadata_change_event(dataset_snapshot)
                yield self.get_metadata_change_proposal(
                    dataset_snapshot.urn,
                    aspect_name="subTypes",
                    aspect=SubTypesClass(typeNames=["View", "Custom SQL"]),
                )

    def get_schema_metadata_for_custom_sql(
            self, columns: List[dict]) -> Optional[SchemaMetadata]:
        schema_metadata = None
        fields = []
        for field in columns:
            # Datasource fields
            nativeDataType = field.get("remoteType", "UNKNOWN")
            TypeClass = FIELD_TYPE_MAPPING.get(nativeDataType, NullTypeClass)
            schema_field = SchemaField(
                fieldPath=field.get("name", ""),
                type=SchemaFieldDataType(type=TypeClass()),
                nativeDataType=nativeDataType,
                description=field.get("description", ""),
            )
            fields.append(schema_field)

        # build the schema once from all columns rather than once per column
        if fields:
            schema_metadata = SchemaMetadata(
                schemaName="test",
                platform=f"urn:li:dataPlatform:{self.platform}",
                version=0,
                fields=fields,
                hash="",
                platformSchema=OtherSchema(rawSchema=""),
            )
        return schema_metadata

    def _create_lineage_from_csql_datasource(
            self, csql_urn: str,
            csql_datasource: List[dict]) -> Iterable[MetadataWorkUnit]:
        for datasource in csql_datasource:
            datasource_urn = builder.make_dataset_urn(self.platform,
                                                      datasource.get("id", ""),
                                                      self.config.env)
            upstream_csql = UpstreamClass(
                dataset=csql_urn,
                type=DatasetLineageTypeClass.TRANSFORMED,
            )

            upstream_lineage = UpstreamLineage(upstreams=[upstream_csql])
            yield self.get_metadata_change_proposal(
                datasource_urn,
                aspect_name="upstreamLineage",
                aspect=upstream_lineage)

    def _create_lineage_to_upstream_tables(
            self, csql_urn: str,
            columns: List[dict]) -> Iterable[MetadataWorkUnit]:
        used_datasources = []
        # Get data sources from columns' reference fields.
        for field in columns:
            data_sources = [
                reference.get("datasource")
                for reference in field.get("referencedByFields", {})
                if reference.get("datasource") is not None
            ]

            for datasource in data_sources:
                if datasource.get("id", "") in used_datasources:
                    continue
                used_datasources.append(datasource.get("id", ""))
                upstream_tables = self._create_upstream_table_lineage(
                    datasource,
                    datasource.get("workbook", {}).get("projectName", ""),
                    True,
                )
                if upstream_tables:
                    upstream_lineage = UpstreamLineage(
                        upstreams=upstream_tables)
                    yield self.get_metadata_change_proposal(
                        csql_urn,
                        aspect_name="upstreamLineage",
                        aspect=upstream_lineage,
                    )

    def _get_schema_metadata_for_embedded_datasource(
            self, datasource_fields: List[dict]) -> Optional[SchemaMetadata]:
        fields = []
        schema_metadata = None
        for field in datasource_fields:
            # check datasource - custom sql relations from a field being referenced
            self._track_custom_sql_ids(field)

            nativeDataType = field.get("dataType", "UNKNOWN")
            TypeClass = FIELD_TYPE_MAPPING.get(nativeDataType, NullTypeClass)

            schema_field = SchemaField(
                fieldPath=field["name"],
                type=SchemaFieldDataType(type=TypeClass()),
                description=make_description_from_params(
                    field.get("description", ""), field.get("formula")),
                nativeDataType=nativeDataType,
                globalTags=get_tags_from_params([
                    field.get("role", ""),
                    field.get("__typename", ""),
                    field.get("aggregation", ""),
                ]) if self.config.ingest_tags else None,
            )
            fields.append(schema_field)

        if fields:
            schema_metadata = SchemaMetadata(
                schemaName="test",
                platform=f"urn:li:dataPlatform:{self.platform}",
                version=0,
                fields=fields,
                hash="",
                platformSchema=OtherSchema(rawSchema=""),
            )

        return schema_metadata

    def get_metadata_change_event(
        self, snap_shot: Union["DatasetSnapshot", "DashboardSnapshot",
                               "ChartSnapshot"]
    ) -> MetadataWorkUnit:
        mce = MetadataChangeEvent(proposedSnapshot=snap_shot)
        work_unit = MetadataWorkUnit(id=snap_shot.urn, mce=mce)
        self.report.report_workunit(work_unit)
        return work_unit

    def get_metadata_change_proposal(
        self,
        urn: str,
        aspect_name: str,
        aspect: Union["UpstreamLineage", "SubTypesClass"],
    ) -> MetadataWorkUnit:
        mcp = MetadataChangeProposalWrapper(
            entityType="dataset",
            changeType=ChangeTypeClass.UPSERT,
            entityUrn=urn,
            aspectName=aspect_name,
            aspect=aspect,
        )
        mcp_workunit = MetadataWorkUnit(
            id=f"tableau-{mcp.entityUrn}-{mcp.aspectName}",
            mcp=mcp,
            treat_errors_as_warnings=True,
        )
        self.report.report_workunit(mcp_workunit)
        return mcp_workunit

    def emit_datasource(
            self,
            datasource: dict,
            workbook: Optional[dict] = None) -> Iterable[MetadataWorkUnit]:
        datasource_info = workbook
        if workbook is None:
            datasource_info = datasource

        project = (datasource_info.get("projectName", "").replace(
            "/", REPLACE_SLASH_CHAR) if datasource_info else "")
        datasource_id = datasource.get("id", "")
        datasource_name = f"{datasource.get('name')}.{datasource_id}"
        datasource_urn = builder.make_dataset_urn(self.platform, datasource_id,
                                                  self.config.env)
        if datasource_id not in self.datasource_ids_being_used:
            self.datasource_ids_being_used.append(datasource_id)

        dataset_snapshot = DatasetSnapshot(
            urn=datasource_urn,
            aspects=[],
        )

        # Browse path
        browse_paths = BrowsePathsClass(paths=[
            f"/{self.config.env.lower()}/{self.platform}/{project}/{datasource.get('name', '')}/{datasource_name}"
        ])
        dataset_snapshot.aspects.append(browse_paths)

        # Ownership
        owner = (self._get_ownership(
            datasource_info.get("owner", {}).get("username", ""))
                 if datasource_info else None)
        if owner is not None:
            dataset_snapshot.aspects.append(owner)

        # Dataset properties
        dataset_props = DatasetPropertiesClass(
            name=datasource.get("name"),
            description=datasource.get("description"),
            customProperties={
                "hasExtracts":
                str(datasource.get("hasExtracts", "")),
                "extractLastRefreshTime":
                datasource.get("extractLastRefreshTime", "") or "",
                "extractLastIncrementalUpdateTime":
                datasource.get("extractLastIncrementalUpdateTime", "") or "",
                "extractLastUpdateTime":
                datasource.get("extractLastUpdateTime", "") or "",
                "type":
                datasource.get("__typename", ""),
            },
        )
        dataset_snapshot.aspects.append(dataset_props)

        # Upstream Tables
        if datasource.get("upstreamTables") is not None:
            # datasource -> db table relations
            upstream_tables = self._create_upstream_table_lineage(
                datasource, project)

            if upstream_tables:
                upstream_lineage = UpstreamLineage(upstreams=upstream_tables)
                yield self.get_metadata_change_proposal(
                    datasource_urn,
                    aspect_name="upstreamLineage",
                    aspect=upstream_lineage,
                )

        # Datasource Fields
        schema_metadata = self._get_schema_metadata_for_embedded_datasource(
            datasource.get("fields", []))
        if schema_metadata is not None:
            dataset_snapshot.aspects.append(schema_metadata)

        yield self.get_metadata_change_event(dataset_snapshot)
        yield self.get_metadata_change_proposal(
            dataset_snapshot.urn,
            aspect_name="subTypes",
            aspect=SubTypesClass(typeNames=["Data Source"]),
        )

        if datasource.get("__typename") == "EmbeddedDatasource":
            yield from add_entity_to_container(self.gen_workbook_key(workbook),
                                               "dataset", dataset_snapshot.urn)

    def emit_published_datasources(self) -> Iterable[MetadataWorkUnit]:
        count_on_query = len(self.datasource_ids_being_used)
        datasource_filter = "idWithin: {}".format(
            json.dumps(self.datasource_ids_being_used))
        (
            published_datasource_conn,
            total_count,
            has_next_page,
        ) = self.get_connection_object(
            published_datasource_graphql_query,
            "publishedDatasourcesConnection",
            datasource_filter,
        )

        current_count = 0
        while has_next_page:
            count = (count_on_query if current_count +
                     count_on_query < total_count else total_count -
                     current_count)
            (
                published_datasource_conn,
                total_count,
                has_next_page,
            ) = self.get_connection_object(
                published_datasource_graphql_query,
                "publishedDatasourcesConnection",
                datasource_filter,
                count,
                current_count,
            )

            current_count += count
            for datasource in published_datasource_conn.get("nodes", []):
                yield from self.emit_datasource(datasource)

    def emit_upstream_tables(self) -> Iterable[MetadataWorkUnit]:
        for (table_urn, (columns, path)) in self.upstream_tables.items():
            dataset_snapshot = DatasetSnapshot(
                urn=table_urn,
                aspects=[],
            )
            # Browse path
            browse_paths = BrowsePathsClass(
                paths=[f"/{self.config.env.lower()}/{self.platform}/{path}"])
            dataset_snapshot.aspects.append(browse_paths)

            fields = []
            for field in columns:
                nativeDataType = field.get("remoteType", "UNKNOWN")
                TypeClass = FIELD_TYPE_MAPPING.get(nativeDataType,
                                                   NullTypeClass)

                schema_field = SchemaField(
                    fieldPath=field["name"],
                    type=SchemaFieldDataType(type=TypeClass()),
                    description="",
                    nativeDataType=nativeDataType,
                )

                fields.append(schema_field)

            schema_metadata = SchemaMetadata(
                schemaName="test",
                platform=f"urn:li:dataPlatform:{self.platform}",
                version=0,
                fields=fields,
                hash="",
                platformSchema=OtherSchema(rawSchema=""),
            )
            dataset_snapshot.aspects.append(schema_metadata)

            yield self.get_metadata_change_event(dataset_snapshot)

    # Older Tableau versions do not support fetching a sheet's upstreamDatasources;
    # this achieves the same effect by walking each datasource's downstreamSheets.
    def get_sheetwise_upstream_datasources(self, workbook: dict) -> dict:
        sheet_upstream_datasources: dict = {}

        for embedded_ds in workbook.get("embeddedDatasources", []):
            for sheet in embedded_ds.get("downstreamSheets", []):
                sheet_upstream_datasources.setdefault(
                    sheet.get("id"), set()).add(embedded_ds.get("id"))

        for published_ds in workbook.get("upstreamDatasources", []):
            for sheet in published_ds.get("downstreamSheets", []):
                sheet_upstream_datasources.setdefault(
                    sheet.get("id"), set()).add(published_ds.get("id"))
        return sheet_upstream_datasources

    def emit_sheets_as_charts(self,
                              workbook: Dict) -> Iterable[MetadataWorkUnit]:
        sheet_upstream_datasources = self.get_sheetwise_upstream_datasources(
            workbook)
        for sheet in workbook.get("sheets", []):
            chart_snapshot = ChartSnapshot(
                urn=builder.make_chart_urn(self.platform, sheet.get("id")),
                aspects=[],
            )

            creator = workbook.get("owner", {}).get("username", "")
            created_at = sheet.get("createdAt", datetime.now())
            updated_at = sheet.get("updatedAt", datetime.now())
            last_modified = self.get_last_modified(creator, created_at,
                                                   updated_at)

            if sheet.get("path"):
                site_part = f"/site/{self.config.site}" if self.config.site else ""
                sheet_external_url = (
                    f"{self.config.connect_uri}/#{site_part}/views/{sheet.get('path')}"
                )
            elif sheet.get("containedInDashboards"):
                # sheet contained in dashboard
                site_part = f"/t/{self.config.site}" if self.config.site else ""
                dashboard_path = sheet.get("containedInDashboards")[0].get(
                    "path", "")
                sheet_external_url = f"{self.config.connect_uri}{site_part}/authoring/{dashboard_path}/{sheet.get('name', '')}"
            else:
                # hidden or viz-in-tooltip sheet
                sheet_external_url = None
            fields = {}
            for field in sheet.get("datasourceFields", ""):
                description = make_description_from_params(
                    get_field_value_in_sheet(field, "description"),
                    get_field_value_in_sheet(field, "formula"),
                )
                fields[get_field_value_in_sheet(field, "name")] = description

            # datasource urn
            datasource_urn = []
            data_sources = sheet_upstream_datasources.get(
                sheet.get("id"), set())

            for ds_id in data_sources:
                if not ds_id:  # covers both None and empty ids
                    continue
                ds_urn = builder.make_dataset_urn(self.platform, ds_id,
                                                  self.config.env)
                datasource_urn.append(ds_urn)
                if ds_id not in self.datasource_ids_being_used:
                    self.datasource_ids_being_used.append(ds_id)

            # Chart Info
            chart_info = ChartInfoClass(
                description="",
                title=sheet.get("name", ""),
                lastModified=last_modified,
                externalUrl=sheet_external_url,
                inputs=sorted(datasource_urn),
                customProperties=fields,
            )
            chart_snapshot.aspects.append(chart_info)

            # Browse path
            browse_path = BrowsePathsClass(paths=[
                f"/{self.platform}/{workbook.get('projectName', '').replace('/', REPLACE_SLASH_CHAR)}"
                f"/{workbook.get('name', '')}"
                f"/{sheet.get('name', '').replace('/', REPLACE_SLASH_CHAR)}"
            ])
            chart_snapshot.aspects.append(browse_path)

            # Ownership
            owner = self._get_ownership(creator)
            if owner is not None:
                chart_snapshot.aspects.append(owner)

            #  Tags
            tag_list = sheet.get("tags", [])
            if tag_list and self.config.ingest_tags:
                tag_list_str = [
                    t.get("name", "").upper() for t in tag_list
                    if t is not None
                ]
                chart_snapshot.aspects.append(
                    builder.make_global_tag_aspect_with_tag_list(tag_list_str))

            yield self.get_metadata_change_event(chart_snapshot)

            yield from add_entity_to_container(self.gen_workbook_key(workbook),
                                               "chart", chart_snapshot.urn)

    def emit_workbook_as_container(
            self, workbook: Dict) -> Iterable[MetadataWorkUnit]:

        workbook_container_key = self.gen_workbook_key(workbook)
        creator = workbook.get("owner", {}).get("username")

        owner_urn = (builder.make_user_urn(creator) if
                     (creator and self.config.ingest_owner) else None)

        site_part = f"/site/{self.config.site}" if self.config.site else ""
        workbook_uri = workbook.get("uri", "")
        workbook_part = (workbook_uri[workbook_uri.index("/workbooks/"):]
                         if workbook.get("uri") else None)
        workbook_external_url = (
            f"{self.config.connect_uri}/#{site_part}{workbook_part}"
            if workbook_part else None)

        tag_list = workbook.get("tags", [])
        tag_list_str = (
            [t.get("name", "").upper() for t in tag_list if t is not None] if
            (tag_list and self.config.ingest_tags) else None)

        container_workunits = gen_containers(
            container_key=workbook_container_key,
            name=workbook.get("name", ""),
            sub_types=["Workbook"],
            description=workbook.get("description"),
            owner_urn=owner_urn,
            external_url=workbook_external_url,
            tags=tag_list_str,
        )

        for wu in container_workunits:
            self.report.report_workunit(wu)
            yield wu

    def gen_workbook_key(self, workbook: Dict) -> WorkbookKey:
        return WorkbookKey(platform=self.platform,
                           instance=None,
                           workbook_id=workbook["id"])

    def emit_dashboards(self, workbook: Dict) -> Iterable[MetadataWorkUnit]:
        for dashboard in workbook.get("dashboards", []):
            dashboard_snapshot = DashboardSnapshot(
                urn=builder.make_dashboard_urn(self.platform,
                                               dashboard.get("id", "")),
                aspects=[],
            )

            creator = workbook.get("owner", {}).get("username", "")
            # createdAt/updatedAt arrive as ISO-8601 strings; default to a
            # parseable string so dp.parse in get_last_modified cannot fail
            created_at = dashboard.get("createdAt", datetime.now().isoformat())
            updated_at = dashboard.get("updatedAt", datetime.now().isoformat())
            last_modified = self.get_last_modified(creator, created_at,
                                                   updated_at)

            site_part = f"/site/{self.config.site}" if self.config.site else ""
            dashboard_external_url = f"{self.config.connect_uri}/#{site_part}/views/{dashboard.get('path', '')}"
            title = (dashboard.get("name") or "").replace(
                "/", REPLACE_SLASH_CHAR)
            chart_urns = [
                builder.make_chart_urn(self.platform, sheet.get("id"))
                for sheet in dashboard.get("sheets", [])
            ]
            dashboard_info_class = DashboardInfoClass(
                description="",
                title=title,
                charts=chart_urns,
                lastModified=last_modified,
                dashboardUrl=dashboard_external_url,
                customProperties={},
            )
            dashboard_snapshot.aspects.append(dashboard_info_class)

            # browse path
            browse_paths = BrowsePathsClass(paths=[
                f"/{self.platform}/{workbook.get('projectName', '').replace('/', REPLACE_SLASH_CHAR)}"
                f"/{workbook.get('name', '').replace('/', REPLACE_SLASH_CHAR)}"
                f"/{title}"
            ])
            dashboard_snapshot.aspects.append(browse_paths)

            # Ownership
            owner = self._get_ownership(creator)
            if owner is not None:
                dashboard_snapshot.aspects.append(owner)

            yield self.get_metadata_change_event(dashboard_snapshot)

            yield from add_entity_to_container(self.gen_workbook_key(workbook),
                                               "dashboard",
                                               dashboard_snapshot.urn)

    def emit_embedded_datasource(self,
                                 workbook: Dict) -> Iterable[MetadataWorkUnit]:
        for datasource in workbook.get("embeddedDatasources", []):
            yield from self.emit_datasource(datasource, workbook)

    @lru_cache(maxsize=None)
    def _get_schema(self, schema_provided: str, database: str) -> str:
        schema = schema_provided
        if not schema_provided and database in self.config.default_schema_map:
            schema = self.config.default_schema_map[database]

        return schema

    @lru_cache(maxsize=None)
    def get_last_modified(self, creator: str, created_at: str,
                          updated_at: str) -> ChangeAuditStamps:
        last_modified = ChangeAuditStamps()
        if creator:
            modified_actor = builder.make_user_urn(creator)
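            # DataHub audit stamps are epoch milliseconds, hence the * 1000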
            created_ts = int(dp.parse(created_at).timestamp() * 1000)
            modified_ts = int(dp.parse(updated_at).timestamp() * 1000)
            last_modified = ChangeAuditStamps(
                created=AuditStamp(time=created_ts, actor=modified_actor),
                lastModified=AuditStamp(time=modified_ts,
                                        actor=modified_actor),
            )
        return last_modified

    @lru_cache(maxsize=None)
    def _get_ownership(self, user: str) -> Optional[OwnershipClass]:
        if self.config.ingest_owner and user:
            owner_urn = builder.make_user_urn(user)
            ownership: OwnershipClass = OwnershipClass(owners=[
                OwnerClass(
                    owner=owner_urn,
                    type=OwnershipTypeClass.DATAOWNER,
                )
            ])
            return ownership

        return None

    @classmethod
    def create(cls, config_dict: dict, ctx: PipelineContext) -> Source:
        config = TableauConfig.parse_obj(config_dict)
        return cls(ctx, config)

    def get_workunits(self) -> Iterable[MetadataWorkUnit]:
        if self.server is None or not self.server.is_signed_in():
            return
        try:
            yield from self.emit_workbooks(self.config.workbooks_page_size)
            if self.datasource_ids_being_used:
                yield from self.emit_published_datasources()
            if self.custom_sql_ids_being_used:
                yield from self.emit_custom_sql_datasources()
        except MetadataQueryException as md_exception:
            self.report.report_failure(
                key="tableau-metadata",
                reason=f"Unable to retrieve metadata from Tableau. "
                f"Information: {md_exception}",
            )

    def get_report(self) -> SourceReport:
        return self.report
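
A minimal driver sketch for the source above. The class name TableauSource and
the exact config keys are assumptions inferred from the snippet, not confirmed
by it; PipelineContext comes from datahub.ingestion.api.common.

from datahub.ingestion.api.common import PipelineContext

# hypothetical config; TableauConfig.parse_obj validates it inside create()
config_dict = {
    "connect_uri": "https://tableau.example.com",
    "site": "my_site",
    "ingest_owner": True,
    "ingest_tags": True,
}

ctx = PipelineContext(run_id="tableau-demo")
source = TableauSource.create(config_dict, ctx)  # assumed class name
for wu in source.get_workunits():  # yields MetadataWorkUnits until exhausted
    print(wu.id)
print(source.get_report())  # SourceReport collecting warnings and failures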
Exemple #15
def extract_dbt_entities(
    all_manifest_entities: Dict[str, Dict[str, Any]],
    all_catalog_entities: Dict[str, Dict[str, Any]],
    sources_results: List[Dict[str, Any]],
    load_catalog: bool,
    target_platform: str,
    environment: str,
    node_type_pattern: AllowDenyPattern,
    report: SourceReport,
) -> List[DBTNode]:

    sources_by_id = {x["unique_id"]: x for x in sources_results}

    dbt_entities = []
    for key, node in all_manifest_entities.items():
        # check if node pattern allowed based on config file
        if not node_type_pattern.allowed(node["resource_type"]):
            continue

        name = node["name"]

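        # a node can override its physical relation name via `identifier`;
        # prefer it when the catalog is not being loaded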
        if "identifier" in node and not load_catalog:
            name = node["identifier"]

        materialization = None
        upstream_urns = []

        if "materialized" in node.get("config", {}).keys():
            # It's a model
            materialization = node["config"]["materialized"]
            upstream_urns = get_upstreams(
                node["depends_on"]["nodes"],
                all_manifest_entities,
                load_catalog,
                target_platform,
                environment,
            )

        # look up the catalog entry (for sources and models alike) to recover
        # the storage type; a missing entry usually means the relation was
        # never built in the warehouse
        catalog_node = all_catalog_entities.get(key)
        catalog_type = None

        if catalog_node is None:
            report.report_warning(
                key,
                f"Entity {name} is in manifest but missing from catalog",
            )
        else:
            catalog_type = catalog_node["metadata"]["type"]

        dbtNode = DBTNode(
            dbt_name=key,
            database=node["database"],
            schema=node["schema"],
            dbt_file_path=node["original_file_path"],
            node_type=node["resource_type"],
            max_loaded_at=sources_by_id.get(key, {}).get("max_loaded_at"),
            name=name,
            upstream_urns=upstream_urns,
            materialization=materialization,
            catalog_type=catalog_type,
            columns=[],
            datahub_urn=get_urn_from_dbtNode(
                node["database"],
                node["schema"],
                name,
                target_platform,
                environment,
            ),
            meta=node.get("meta", {}),
        )

        # overwrite columns from the catalog; we skip this when load_catalog
        # is False (i.e. the platform isn't dbt) and for ephemeral models,
        # which never materialize and so have no catalog entry
        if dbtNode.materialization != "ephemeral" and load_catalog:
            logger.debug("Loading schema info")
            if catalog_node is not None:
                # a missing catalog entry was already reported above
                dbtNode.columns = get_columns(catalog_node)

        dbt_entities.append(dbtNode)

    return dbt_entities
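
For orientation, a minimal sketch of the inputs extract_dbt_entities consumes.
The node layout mirrors dbt's manifest.json and catalog.json; the concrete
values and the "postgres" platform are illustrative assumptions only.

from datahub.configuration.common import AllowDenyPattern

# one manifest node, shaped like dbt's manifest.json "nodes" entries
manifest_nodes = {
    "model.demo.my_model": {
        "resource_type": "model",
        "name": "my_model",
        "database": "analytics",
        "schema": "public",
        "original_file_path": "models/my_model.sql",
        "config": {"materialized": "table"},
        "depends_on": {"nodes": []},  # empty: no upstream lineage here
    }
}
# the matching catalog.json entry supplies the physical type and columns
catalog_nodes = {
    "model.demo.my_model": {
        "metadata": {"type": "BASE TABLE"},
        "columns": {
            "id": {"name": "id", "type": "integer", "index": 1, "comment": None}
        },
    }
}

nodes = extract_dbt_entities(
    all_manifest_entities=manifest_nodes,
    all_catalog_entities=catalog_nodes,
    sources_results=[],
    load_catalog=True,
    target_platform="postgres",
    environment="PROD",
    node_type_pattern=AllowDenyPattern.allow_all(),
    report=SourceReport(),
)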
Exemple #16
    @classmethod
    def from_api(  # noqa: C901
        cls,
        model: str,
        explore_name: str,
        client: Looker31SDK,
        reporter: SourceReport,
        transport_options: Optional[TransportOptions],
    ) -> Optional["LookerExplore"]:
        try:
            explore = client.lookml_model_explore(
                model, explore_name, transport_options=transport_options)
            views = set()

            if explore.view_name is not None and explore.view_name != explore.name:
                # explore is not named after a view and is instead using a from field, which is modeled as view_name.
                aliased_explore = True
                views.add(explore.view_name)
            else:
                # otherwise, the explore name is a view, so add it to the set.
                aliased_explore = False
                if explore.name is not None:
                    views.add(explore.name)

            if explore.joins:
                potential_views = [
                    e.name for e in explore.joins if e.name is not None
                ]
                for e_join in [
                        e for e in explore.joins
                        if e.dependent_fields is not None
                ]:
                    assert e_join.dependent_fields is not None
                    for field_name in e_join.dependent_fields:
                        try:
                            view_name = LookerUtil._extract_view_from_field(
                                field_name)
                            potential_views.append(view_name)
                        except AssertionError:
                            reporter.report_warning(
                                key=f"chart-field-{field_name}",
                                reason=
                                "The field was not prefixed by a view name. This can happen when the field references another dynamic field.",
                            )
                            continue

                for view_name in potential_views:
                    if (view_name == explore.name) and aliased_explore:
                        # if the explore is aliased, then the joins could be referring to views using the aliased name.
                        # this needs to be corrected by switching to the actual view name of the explore
                        assert explore.view_name is not None
                        view_name = explore.view_name
                    views.add(view_name)

            view_fields: List[ViewField] = []
            if explore.fields is not None:
                if explore.fields.dimensions is not None:
                    for dim_field in explore.fields.dimensions:
                        if dim_field.name is None:
                            continue
                        else:
                            view_fields.append(
                                ViewField(
                                    name=dim_field.name,
                                    description=dim_field.description
                                    if dim_field.description else "",
                                    type=dim_field.type
                                    if dim_field.type is not None else "",
                                    field_type=ViewFieldType.DIMENSION_GROUP
                                    if dim_field.dimension_group is not None
                                    else ViewFieldType.DIMENSION,
                                    is_primary_key=dim_field.primary_key
                                    if dim_field.primary_key else False,
                                ))
                if explore.fields.measures is not None:
                    for measure_field in explore.fields.measures:
                        if measure_field.name is None:
                            continue
                        else:
                            view_fields.append(
                                ViewField(
                                    name=measure_field.name,
                                    description=measure_field.description
                                    if measure_field.description else "",
                                    type=measure_field.type
                                    if measure_field.type is not None else "",
                                    field_type=ViewFieldType.MEASURE,
                                    is_primary_key=measure_field.primary_key
                                    if measure_field.primary_key else False,
                                ))

            return cls(
                name=explore_name,
                model_name=model,
                project_name=explore.project_name,
                label=explore.label,
                description=explore.description,
                fields=view_fields,
                upstream_views=list(views),
                source_file=explore.source_file,
            )
        except SDKError as e:
            logger.warning(
                f"Failed to extract explore {explore_name} from model {model}."
            )
            logger.debug(
                f"Failed to extract explore {explore_name} from model {model} with {e}"
            )
        except AssertionError:
            reporter.report_warning(
                key="chart-",
                reason="Was unable to find dependent views for this chart",
            )
        return None
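
The AssertionError handling above leans on Looker's naming convention: explore
fields are qualified as "<view>.<field>". A plausible sketch of what
LookerUtil._extract_view_from_field does, inferred from its usage here rather
than taken from the source:

def extract_view_from_field(field: str) -> str:
    # assumed behavior: a field must contain exactly one dot separating the
    # view prefix from the field name; anything else raises AssertionError,
    # which from_api above catches and reports as a warning
    assert field.count(".") == 1, (
        f"Error: field name '{field}' is not prefixed by a view name")
    return field.split(".")[0]

assert extract_view_from_field("orders.total_amount") == "orders"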