예제 #1
0
    def process_labeling_job(self, job: Dict[str, Any]) -> SageMakerJob:
        """
        Turn a Boto3 describe_labeling_job() response into a SageMakerJob.

        Inputs are the input manifest and the optional label-category config
        file; outputs are the labeled output dataset and the configured S3
        output path. Each S3 location is keyed by its S3 dataset URN.

        See https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.describe_labeling_job
        """

        JOB_TYPE = JobType.LABELING

        manifest_uri: Optional[str] = (
            job.get("InputConfig", {})
            .get("DataSource", {})
            .get("S3DataSource", {})
            .get("ManifestS3Uri")
        )
        label_config_uri: Optional[str] = job.get("LabelCategoryConfigS3Uri")

        input_datasets = {}
        for uri in (manifest_uri, label_config_uri):
            if uri is None:
                continue
            input_datasets[make_s3_urn(uri, self.env)] = {
                "dataset_type": "s3",
                "uri": uri,
            }

        labeled_output_uri: Optional[str] = job.get(
            "LabelingJobOutput", {}
        ).get("OutputDatasetS3Uri")
        configured_output_uri: Optional[str] = job.get(
            "OutputConfig", {}
        ).get("S3OutputPath")

        output_datasets = {}
        for uri in (labeled_output_uri, configured_output_uri):
            if uri is None:
                continue
            output_datasets[make_s3_urn(uri, self.env)] = {
                "dataset_type": "s3",
                "uri": uri,
            }

        console_url = f"https://{self.aws_region}.console.aws.amazon.com/sagemaker/home?region={self.aws_region}#/labeling-jobs/{job['LabelingJobName']}"
        job_snapshot, job_name, job_arn = self.create_common_job_snapshot(
            job, JOB_TYPE, console_url
        )

        return SageMakerJob(
            job_name=job_name,
            job_arn=job_arn,
            job_type=JOB_TYPE,
            job_snapshot=job_snapshot,
            input_datasets=input_datasets,
            output_datasets=output_datasets,
        )
예제 #2
0
    def process_auto_ml_job(self, job: Dict[str, Any]) -> SageMakerJob:
        """
        Turn a Boto3 describe_auto_ml_job() response into a SageMakerJob.

        Inputs are the S3 sources of every InputDataConfig entry that carries
        an S3Uri; the output is OutputDataConfig.S3OutputPath. Every model
        artifact of the best candidate's inference containers is registered
        with this job as its training job.

        See https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.describe_auto_ml_job
        """

        JOB_TYPE = JobType.AUTO_ML

        input_datasets = {}
        for config in job.get("InputDataConfig", []):
            s3_source = config.get("DataSource", {}).get("S3DataSource")
            if s3_source is None or "S3Uri" not in s3_source:
                continue
            source_uri = s3_source["S3Uri"]
            input_datasets[make_s3_urn(source_uri, self.env)] = {
                "dataset_type": "s3",
                "uri": source_uri,
                "datatype": s3_source.get("S3DataType"),
            }

        output_datasets = {}
        output_path = job.get("OutputDataConfig", {}).get("S3OutputPath")
        if output_path is not None:
            output_datasets[make_s3_urn(output_path, self.env)] = {
                "dataset_type": "s3",
                "uri": output_path,
            }

        job_snapshot, job_name, job_arn = self.create_common_job_snapshot(
            job, JOB_TYPE
        )

        best_candidate = job.get("BestCandidate", {})
        for container in best_candidate.get("InferenceContainers", []):
            artifact_url = container.get("ModelDataUrl")
            if artifact_url is None:
                continue
            self.update_model_image_jobs(
                artifact_url,
                JobKey(job_snapshot.urn, JobDirection.TRAINING),
            )

        return SageMakerJob(
            job_name=job_name,
            job_arn=job_arn,
            job_type=JOB_TYPE,
            job_snapshot=job_snapshot,
            input_datasets=input_datasets,
            output_datasets=output_datasets,
        )
예제 #3
0
    def process_compilation_job(self, job: Dict[str, Any]) -> SageMakerJob:
        """
        Turn a Boto3 describe_compilation_job() response into a SageMakerJob.

        The input dataset is InputConfig.S3Uri (annotated with the framework
        and framework version); the output dataset is
        OutputConfig.S3OutputLocation (annotated with the target device and
        target platform).

        See https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.describe_compilation_job
        """

        JOB_TYPE = JobType.COMPILATION

        input_datasets = {}
        input_config: Optional[Dict[str, Any]] = job.get("InputConfig")
        if input_config is not None and "S3Uri" in input_config:
            model_uri = input_config["S3Uri"]
            input_datasets[make_s3_urn(model_uri, self.env)] = {
                "dataset_type": "s3",
                "uri": model_uri,
                "framework": input_config.get("Framework"),
                "framework_version": input_config.get("FrameworkVersion"),
            }

        output_datasets = {}
        output_config: Optional[Dict[str, Any]] = job.get("OutputConfig")
        if output_config is not None and "S3OutputLocation" in output_config:
            output_location = output_config["S3OutputLocation"]
            output_datasets[make_s3_urn(output_location, self.env)] = {
                "dataset_type": "s3",
                "uri": output_location,
                "target_device": output_config.get("TargetDevice"),
                "target_platform": output_config.get("TargetPlatform"),
            }

        console_url = f"https://{self.aws_region}.console.aws.amazon.com/sagemaker/home?region={self.aws_region}#/compilation-jobs/{job['CompilationJobName']}"
        job_snapshot, job_name, job_arn = self.create_common_job_snapshot(
            job, JOB_TYPE, console_url
        )

        return SageMakerJob(
            job_name=job_name,
            job_arn=job_arn,
            job_type=JOB_TYPE,
            job_snapshot=job_snapshot,
            input_datasets=input_datasets,
            output_datasets=output_datasets,
        )
예제 #4
0
 def get_lineage_if_enabled(
     self, mce: MetadataChangeEventClass
 ) -> Optional[MetadataChangeProposalWrapper]:
     """
     Build an upstreamLineage proposal linking a Glue table to its S3 location.

     Only applies when ``emit_s3_lineage`` is enabled and the dataset's
     properties carry a "Location" custom property starting with "s3://".
     With ``glue_s3_lineage_direction == "upstream"`` the Glue dataset gets
     the S3 dataset as its upstream; otherwise the S3 dataset is given the
     Glue dataset as its upstream. Returns None when nothing applies.
     """
     if not self.source_config.emit_s3_lineage:
         return None

     dataset_properties: Optional[DatasetPropertiesClass] = (
         mce_builder.get_aspect_if_available(mce, DatasetPropertiesClass)
     )
     if (
         not dataset_properties
         or "Location" not in dataset_properties.customProperties
     ):
         return None

     location = dataset_properties.customProperties["Location"]
     if not location.startswith("s3://"):
         return None

     s3_dataset_urn = make_s3_urn(location, self.source_config.env)
     glue_dataset_urn = mce.proposedSnapshot.urn

     if self.source_config.glue_s3_lineage_direction == "upstream":
         # The Glue table is downstream of the S3 data it reads from.
         downstream_urn, upstream_urn = glue_dataset_urn, s3_dataset_urn
     else:
         # Mint the S3 dataset with lineage pointing back at the Glue table.
         downstream_urn, upstream_urn = s3_dataset_urn, glue_dataset_urn

     return MetadataChangeProposalWrapper(
         entityType="dataset",
         entityUrn=downstream_urn,
         changeType=ChangeTypeClass.UPSERT,
         aspectName="upstreamLineage",
         aspect=UpstreamLineageClass(upstreams=[
             UpstreamClass(
                 dataset=upstream_urn,
                 type=DatasetLineageTypeClass.COPY,
             )
         ]),
     )
예제 #5
0
    def get_table_properties(
        self, inspector: Inspector, schema: str, table: str
    ) -> Tuple[Optional[str], Optional[Dict[str, str]], Optional[str]]:
        """
        Fetch an Athena table's comment, custom properties and S3 location URN.

        Returns ``(description, custom_properties, location)``, where
        ``custom_properties`` holds the JSON-encoded partition keys, every
        table parameter (falsy values become ""), and the create time, last
        access time and table type. ``location`` is the S3 dataset URN of the
        "location" property, or None when it is absent or not an s3:// URL.
        """
        if not self.cursor:
            raw_connection = inspector.dialect._raw_connection(inspector.engine)
            self.cursor = raw_connection.cursor()

        assert self.cursor
        # Unfortunately properties can be only get through private methods as those are not exposed
        # https://github.com/laughingman7743/PyAthena/blob/9e42752b0cc7145a87c3a743bb2634fe125adfa7/pyathena/model.py#L201
        metadata: AthenaTableMetadata = self.cursor._get_table_metadata(
            table_name=table, schema_name=schema)
        description = metadata.comment

        partition_keys = [
            {
                "name": partition.name,
                "type": partition.type,
                "comment": partition.comment if partition.comment else "",
            }
            for partition in metadata.partition_keys
        ]
        custom_properties: Dict[str, str] = {
            "partition_keys": json.dumps(partition_keys)
        }
        # Table parameters may overwrite "partition_keys"; the fixed keys
        # below are written afterwards and therefore always win.
        custom_properties.update(
            (key, value if value else "")
            for key, value in metadata.parameters.items()
        )

        create_time = metadata.create_time
        last_access_time = metadata.last_access_time
        table_type = metadata.table_type
        custom_properties["create_time"] = str(create_time) if create_time else ""
        custom_properties["last_access_time"] = (
            str(last_access_time) if last_access_time else ""
        )
        custom_properties["table_type"] = table_type if table_type else ""

        raw_location: Optional[str] = custom_properties.get("location", None)
        location_urn: Optional[str] = None
        if raw_location is not None:
            if raw_location.startswith("s3://"):
                location_urn = make_s3_urn(raw_location, self.config.env)
            else:
                logging.debug(
                    f"Only s3 url supported for location. Skipping {raw_location}")

        return description, custom_properties, location_urn
예제 #6
0
    def _get_upstream_lineage_info(
        self, dataset_urn: str
    ) -> Optional[Tuple[UpstreamLineage, Dict[str, str]]]:
        """
        Collect upstream table lineage and column-lineage strings for a dataset.

        The lineage and external-lineage maps are populated lazily on first use.
        Table lineage entries are indexed as (upstream table name, upstream
        columns JSON, downstream columns JSON); upstreams rejected by
        ``_is_dataset_allowed`` are skipped. External lineage entries are
        added as COPY upstreams only when they are s3:// locations.

        Returns ``(UpstreamLineage, column_lineage)``, or None when the URN has
        no key or no upstream was found.
        """
        dataset_key = builder.dataset_urn_to_key(dataset_urn)
        if dataset_key is None:
            logger.warning(
                f"Invalid dataset urn {dataset_urn}. Could not get key!")
            return None

        if self._lineage_map is None:
            self._populate_lineage()
            self._populate_view_lineage()
        if self._external_lineage_map is None:
            self._populate_external_lineage()
        assert self._lineage_map is not None
        assert self._external_lineage_map is not None

        dataset_name = dataset_key.name
        lineage = self._lineage_map[dataset_name]
        external_lineage = self._external_lineage_map[dataset_name]
        if not lineage and not external_lineage:
            logger.debug(f"No lineage found for {dataset_name}")
            return None

        upstream_tables: List[UpstreamClass] = []
        column_lineage: Dict[str, str] = {}

        for entry in lineage:
            upstream_name = entry[0]
            if not self._is_dataset_allowed(upstream_name):
                continue

            upstream_tables.append(
                UpstreamClass(
                    dataset=builder.make_dataset_urn_with_platform_instance(
                        self.platform,
                        upstream_name,
                        self.config.platform_instance,
                        self.config.env,
                    ),
                    type=DatasetLineageTypeClass.TRANSFORMED,
                ))

            # Column lists are lower-cased and sorted for a stable rendering.
            upstream_cols = ", ".join(
                sorted(col["columnName"].lower() for col in json.loads(entry[1])))
            downstream_cols = ", ".join(
                sorted(col["columnName"].lower() for col in json.loads(entry[2])))
            key = f"column_lineage[{upstream_name}]"
            value = (f"{{{upstream_name}({upstream_cols}) -> "
                     f"{dataset_name}({downstream_cols})}}")
            column_lineage[key] = value
            logger.debug(f"{key}:{value}")

        # For now, external lineage is populated only for S3.
        upstream_tables.extend(
            UpstreamClass(
                dataset=make_s3_urn(location, self.config.env),
                type=DatasetLineageTypeClass.COPY,
            )
            for location in external_lineage
            if location.startswith("s3://"))

        if not upstream_tables:
            return None

        upstream_urns = [upstream.dataset for upstream in upstream_tables]
        logger.debug(f"Upstream lineage of '{dataset_name}': {upstream_urns}")
        if self.config.report_upstream_lineage:
            self.report.upstream_lineage[dataset_name] = upstream_urns
        return UpstreamLineage(upstreams=upstream_tables), column_lineage
예제 #7
0
    def process_dataflow_node(
        self,
        node: Dict[str, Any],
        flow_urn: str,
        new_dataset_ids: List[str],
        new_dataset_mces: List[MetadataChangeEvent],
        s3_formats: typing.DefaultDict[str, Set[Union[str, None]]],
    ) -> Optional[Dict[str, Any]]:
        """
        Resolve a Glue dataflow-graph node to a URN-annotated node dict.

        DataSource/DataSink nodes become dataset URNs. A Glue table reuses the
        Glue dataset URN. An S3 object gets a newly minted dataset: its MCE is
        appended to ``new_dataset_mces`` and its node id to
        ``new_dataset_ids``. Any other node type is treated as a
        transformation and becomes a data-job URN under ``flow_urn``.

        node: graph node; reads "NodeType", "Id" and, for data nodes, "Args".
        flow_urn: URN of the data flow the node belongs to.
        new_dataset_ids / new_dataset_mces: output lists extended in place.
        s3_formats: formats seen per S3 URI; when a URI has more than one
            format, the format is appended to its URN to keep them distinct.

        Returns a copy of ``node`` with "urn" plus empty lineage lists (filled
        in later while traversing edges), or None when the node is skipped.
        Raises ValueError for an unrecognised data object type unless
        ``ignore_unsupported_connectors`` is enabled.
        """

        node_type = node["NodeType"]

        # for nodes representing datasets, we construct a dataset URN accordingly
        if node_type in ["DataSource", "DataSink"]:

            node_args = {
                x["Name"]: yaml.safe_load(x["Value"])
                for x in node["Args"]
            }

            # if data object is Glue table
            if "database" in node_args and "table_name" in node_args:

                full_table_name = f"{node_args['database']}.{node_args['table_name']}"

                # we know that the table will already be covered when ingesting Glue tables
                node_urn = make_dataset_urn_with_platform_instance(
                    platform=self.platform,
                    name=full_table_name,
                    env=self.env,
                    platform_instance=self.source_config.platform_instance,
                )

            # if data object is S3 bucket
            elif node_args.get("connection_type") == "s3":

                s3_uri = self.get_s3_uri(node_args)

                if s3_uri is None:
                    # The key is "NodeType" (see above); the previous
                    # "Nodetype" lookup raised KeyError instead of warning.
                    node_id = f"{node_type}-{node['Id']}"
                    self.report.report_warning(
                        node_id,
                        f"Could not find S3 URI for node {node_id} in flow {flow_urn}. Skipping",
                    )
                    return None

                # append S3 format if different ones exist
                if len(s3_formats[s3_uri]) > 1:
                    node_urn = make_s3_urn(
                        f"{s3_uri}.{node_args.get('format')}",
                        self.env,
                    )

                else:
                    node_urn = make_s3_urn(s3_uri, self.env)

                dataset_snapshot = DatasetSnapshot(
                    urn=node_urn,
                    aspects=[],
                )

                dataset_snapshot.aspects.append(Status(removed=False))
                dataset_snapshot.aspects.append(
                    DatasetPropertiesClass(
                        customProperties={
                            k: str(v)
                            for k, v in node_args.items()
                        },
                        tags=[],
                    ))

                new_dataset_mces.append(
                    MetadataChangeEvent(proposedSnapshot=dataset_snapshot))
                new_dataset_ids.append(f"{node_type}-{node['Id']}")

            else:

                if self.source_config.ignore_unsupported_connectors:

                    # Pass the values as %-style arguments. The previous call
                    # passed flow_urn as the format string, so logging failed
                    # with "not all arguments converted" and dropped the message.
                    logger.info(
                        "%s: Unrecognized Glue data object type: %s. Skipping.",
                        flow_urn,
                        node_args,
                    )
                    return None
                else:

                    raise ValueError(
                        f"Unrecognized Glue data object type: {node_args}")

        # otherwise, a node represents a transformation
        else:
            node_urn = mce_builder.make_data_job_urn_with_flow(
                flow_urn, job_id=f'{node_type}-{node["Id"]}')

        return {
            **node,
            "urn": node_urn,
            # to be filled in after traversing edges
            "inputDatajobs": [],
            "inputDatasets": [],
            "outputDatasets": [],
        }
예제 #8
0
    def process_transform_job(self, job: Dict[str, Any]) -> SageMakerJob:
        """
        Turn a Boto3 describe_transform_job() response into a SageMakerJob.

        The input dataset is TransformInput's S3 data source (annotated with
        data type, compression and split type); the output dataset is
        TransformOutput.S3OutputPath. Labeling and AutoML jobs referenced by
        ARN become input jobs when ``self.arn_to_name`` knows them. When a
        ModelName is present, the job is registered as a downstream job of
        that model.

        See https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.describe_transform_job
        """

        JOB_TYPE = JobType.TRANSFORM

        transform_input = job.get("TransformInput", {})
        s3_source = transform_input.get("DataSource", {}).get("S3DataSource", {})
        source_uri = s3_source.get("S3Uri")

        input_datasets = {}
        if source_uri is not None:
            input_datasets[make_s3_urn(source_uri, self.env)] = {
                "dataset_type": "s3",
                "uri": source_uri,
                "datatype": s3_source.get("S3DataType"),
                "compression": transform_input.get("CompressionType"),
                "split": transform_input.get("SplitType"),
            }

        output_datasets = {}
        destination_uri = job.get("TransformOutput", {}).get("S3OutputPath")
        if destination_uri is not None:
            output_datasets[make_s3_urn(destination_uri, self.env)] = {
                "dataset_type": "s3",
                "uri": destination_uri,
            }

        input_jobs = set()
        for upstream_arn in (job.get("LabelingJobArn"), job.get("AutoMLJobArn")):
            if upstream_arn is None:
                continue
            upstream_type, upstream_name = self.arn_to_name.get(
                upstream_arn, (None, None))
            if upstream_type is None or upstream_name is None:
                continue
            input_jobs.add(
                make_sagemaker_job_urn(upstream_type, upstream_name,
                                       upstream_arn, self.env))

        console_url = f"https://{self.aws_region}.console.aws.amazon.com/sagemaker/home?region={self.aws_region}#/transform-jobs/{job['TransformJobName']}"
        job_snapshot, job_name, job_arn = self.create_common_job_snapshot(
            job, JOB_TYPE, console_url
        )

        model_name = job.get("ModelName")
        if model_name is not None:
            self.update_model_name_jobs(
                model_name,
                JobKey(job_snapshot.urn, JobDirection.DOWNSTREAM),
            )

        return SageMakerJob(
            job_name=job_name,
            job_arn=job_arn,
            job_type=JOB_TYPE,
            job_snapshot=job_snapshot,
            input_datasets=input_datasets,
            output_datasets=output_datasets,
            input_jobs=input_jobs,
        )
예제 #9
0
    def process_training_job(self, job: Dict[str, Any]) -> SageMakerJob:
        """
        Process outputs from Boto3 describe_training_job()

        Inputs are the S3 sources of every InputDataConfig channel. Outputs are
        every S3 location the job writes to: model output, checkpoints,
        debugger hook, TensorBoard, profiler, and each debug/profiler rule's
        output path. When model artifacts were produced, the job is registered
        against them as their training job, together with the latest value of
        each final metric and the job's hyperparameters.

        See https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.describe_training_job
        """

        JOB_TYPE = JobType.TRAINING

        input_datasets = {}

        for config in job.get("InputDataConfig", []):

            s3_source = config.get("DataSource", {}).get("S3DataSource", {})
            s3_uri = s3_source.get("S3Uri")

            if s3_uri is not None:
                input_datasets[make_s3_urn(s3_uri, self.env)] = {
                    "dataset_type": "s3",
                    "uri": s3_uri,
                    # Boto3's key is "S3DataType" (capital T); the previous
                    # "S3Datatype" lookup always yielded None.
                    "datatype": s3_source.get("S3DataType"),
                    "distribution_type":
                    s3_source.get("S3DataDistributionType"),
                    "attribute_names": s3_source.get("AttributeNames"),
                    "channel_name": config.get("ChannelName"),
                }

        candidate_output_uris = [
            job.get("OutputDataConfig", {}).get("S3OutputPath"),
            job.get("CheckpointConfig", {}).get("S3Uri"),
            job.get("DebugHookConfig", {}).get("S3OutputPath"),
            job.get("TensorBoardOutputConfig", {}).get("S3OutputPath"),
            job.get("ProfilerConfig", {}).get("S3OutputPath"),
            *(rule.get("S3OutputPath")
              for rule in job.get("DebugRuleConfigurations", [])),
            *(rule.get("S3OutputPath")
              for rule in job.get("ProfilerRuleConfigurations", [])),
        ]

        # process all output datasets at once; unset locations are skipped
        output_datasets = {}
        for uri in candidate_output_uris:
            if uri is not None:
                output_datasets[make_s3_urn(uri, self.env)] = {
                    "dataset_type": "s3",
                    "uri": uri,
                }

        job_snapshot, job_name, job_arn = self.create_common_job_snapshot(
            job,
            JOB_TYPE,
            f"https://{self.aws_region}.console.aws.amazon.com/sagemaker/home?region={self.aws_region}#/jobs/{job['TrainingJobName']}",
        )

        model_data_url = job.get("ModelArtifacts", {}).get("S3ModelArtifacts")

        # Sort by metric name, then newest first, so that the first entry
        # seen for each metric name is its last recorded value.
        sorted_metrics = sorted(
            job.get("FinalMetricDataList", []),
            key=lambda metric: (metric["MetricName"], metric["Timestamp"]),
            reverse=True,
        )
        metrics: Dict[str, Any] = {}
        for metric in sorted_metrics:
            metrics.setdefault(metric["MetricName"], metric["Value"])

        if model_data_url is not None:

            job_key = JobKey(job_snapshot.urn, JobDirection.TRAINING)

            self.update_model_image_jobs(
                model_data_url,
                job_key,
                metrics,
                job.get("HyperParameters", {}),
            )

        return SageMakerJob(
            job_name=job_name,
            job_arn=job_arn,
            job_type=JOB_TYPE,
            job_snapshot=job_snapshot,
            input_datasets=input_datasets,
            output_datasets=output_datasets,
        )
예제 #10
0
    def process_processing_job(self, job: Dict[str, Any]) -> SageMakerJob:
        """
        Process outputs from Boto3 describe_processing_job()

        Input jobs are the AutoML and training jobs referenced by ARN, when
        ``self.arn_to_name`` knows them. Input datasets are the S3 sources of
        ProcessingInputs. Output datasets are the S3 destinations and
        SageMaker feature groups listed in ProcessingOutputConfig.Outputs.

        See https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.describe_processing_job
        """

        JOB_TYPE = JobType.PROCESSING

        input_jobs = set()

        auto_ml_arn: Optional[str] = job.get("AutoMLJobArn")
        training_arn: Optional[str] = job.get("TrainingJobArn")

        for upstream_arn in (auto_ml_arn, training_arn):
            if upstream_arn is None:
                continue
            upstream_type, upstream_name = self.arn_to_name.get(
                upstream_arn, (None, None))
            if upstream_type is not None and upstream_name is not None:
                input_jobs.add(
                    make_sagemaker_job_urn(upstream_type, upstream_name,
                                           upstream_arn, self.env))

        input_datasets = {}

        # ProcessingInputs is optional when creating a processing job, so
        # default to no inputs instead of raising KeyError.
        for input_config in job.get("ProcessingInputs", []):

            input_name = input_config["InputName"]

            input_s3 = input_config.get("S3Input", {})
            input_s3_uri = input_s3.get("S3Uri")

            if input_s3_uri is not None:

                input_datasets[make_s3_urn(input_s3_uri, self.env)] = {
                    "dataset_type": "s3",
                    "uri": input_s3_uri,
                    "datatype": input_s3.get("S3DataType"),
                    "mode": input_s3.get("S3InputMode"),
                    "distribution_type":
                    input_s3.get("S3DataDistributionType"),
                    "compression": input_s3.get("S3CompressionType"),
                    "name": input_name,
                }

            # TODO: ingest Athena and Redshift data sources
            # (DatasetDefinition.AthenaDatasetDefinition /
            # DatasetDefinition.RedshiftDatasetDefinition). We don't do this at
            # the moment because we need to parse the QueryString SQL in order
            # to get the tables used (otherwise we just have databases).

        outputs: List[Dict[str, Any]] = job.get("ProcessingOutputConfig",
                                                {}).get("Outputs", [])

        output_datasets = {}

        for output in outputs:
            output_name = output["OutputName"]

            output_s3_uri = output.get("S3Output", {}).get("S3Uri")
            if output_s3_uri is not None:
                output_datasets[make_s3_urn(output_s3_uri, self.env)] = {
                    "dataset_type": "s3",
                    "uri": output_s3_uri,
                    "name": output_name,
                }

            output_feature_group = output.get("FeatureStoreOutput",
                                              {}).get("FeatureGroupName")
            if output_feature_group is not None:
                output_datasets[mce_builder.make_ml_feature_table_urn(
                    "sagemaker", output_feature_group)] = {
                        "dataset_type": "sagemaker_feature_group",
                    }

        job_snapshot, job_name, job_arn = self.create_common_job_snapshot(
            job,
            JOB_TYPE,
            f"https://{self.aws_region}.console.aws.amazon.com/sagemaker/home?region={self.aws_region}#/processing-jobs/{job['ProcessingJobName']}",
        )

        # output_datasets was previously computed but never passed on, which
        # silently dropped all processing-job output lineage.
        return SageMakerJob(
            job_name=job_name,
            job_arn=job_arn,
            job_type=JOB_TYPE,
            job_snapshot=job_snapshot,
            input_datasets=input_datasets,
            output_datasets=output_datasets,
            input_jobs=input_jobs,
        )
예제 #11
0
    def process_edge_packaging_job(
        self,
        job: Dict[str, Any],
    ) -> SageMakerJob:
        """
        Turn a Boto3 describe_edge_packaging_job() response into a SageMakerJob.

        Output datasets are ModelArtifact and OutputConfig.S3OutputLocation.
        The Neo compilation job named by CompilationJobName is linked through
        ``self.name_to_arn``; a warning is reported when its ARN is unknown.
        When a ModelName is present, the job is registered as a downstream job
        of that model.

        See https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.describe_edge_packaging_job
        """

        JOB_TYPE = JobType.EDGE_PACKAGING

        name: str = job["EdgePackagingJobName"]
        arn: str = job["EdgePackagingJobArn"]

        artifact_uri: Optional[str] = job.get("ModelArtifact")
        packaged_uri: Optional[str] = job.get("OutputConfig",
                                              {}).get("S3OutputLocation")

        output_datasets = {}
        for uri in (artifact_uri, packaged_uri):
            if uri is None:
                continue
            output_datasets[make_s3_urn(uri, self.env)] = {
                "dataset_type": "s3",
                "uri": uri,
            }

        # From the docs: "The name of the SageMaker Neo compilation job that is
        # used to locate model artifacts that are being packaged."
        # NOTE(review): that compilation job feeds this one, yet it is recorded
        # in output_jobs; confirm the intended lineage direction.
        compilation_job_name: Optional[str] = job.get("CompilationJobName")

        output_jobs = set()
        if compilation_job_name is not None:
            # name_to_arn is keyed by (job type, job name) pairs.
            lookup_key = ("compilation", compilation_job_name)
            if lookup_key not in self.name_to_arn:
                self.report.report_warning(
                    name,
                    f"Unable to find ARN for compilation job {compilation_job_name} produced by edge packaging job {arn}",
                )
            else:
                output_jobs.add(
                    make_sagemaker_job_urn(
                        "compilation",
                        compilation_job_name,
                        self.name_to_arn[lookup_key],
                        self.env,
                    ))

        console_url = f"https://{self.aws_region}.console.aws.amazon.com/sagemaker/home?region={self.aws_region}#/edge-packaging-jobs/{job['EdgePackagingJobName']}"
        job_snapshot, job_name, job_arn = self.create_common_job_snapshot(
            job, JOB_TYPE, console_url
        )

        model_name = job.get("ModelName")
        if model_name is not None:
            self.update_model_name_jobs(
                model_name, JobKey(job_snapshot.urn, JobDirection.DOWNSTREAM))

        return SageMakerJob(
            job_name=job_name,
            job_arn=job_arn,
            job_type=JOB_TYPE,
            job_snapshot=job_snapshot,
            output_datasets=output_datasets,
            output_jobs=output_jobs,
        )