def test_lineage_query_duplication(sagemaker_session):
    """Check that duplicated vertices and edges in the service response are reported once.

    The mocked ``query_lineage`` payload repeats one vertex and one edge; the
    query response is expected to expose two vertices and a single edge.
    """
    lineage_query = LineageQuery(sagemaker_session)

    endpoint_vertex = {"Arn": "arn1", "Type": "Endpoint", "LineageType": "Artifact"}
    model_vertex = {"Arn": "arn2", "Type": "Model", "LineageType": "Context"}
    produced_edge = {"SourceArn": "arn1", "DestinationArn": "arn2", "AssociationType": "Produced"}

    # Copies keep the duplicated entries as distinct dict objects, as in a real payload.
    sagemaker_session.sagemaker_client.query_lineage.return_value = {
        "Vertices": [endpoint_vertex, dict(endpoint_vertex), model_vertex],
        "Edges": [produced_edge, dict(produced_edge)],
    }

    response = lineage_query.query(
        start_arns=["arn:aws:sagemaker:us-west-2:0123456789012:context/mycontext"]
    )

    observed_edges = [
        (edge.source_arn, edge.destination_arn, edge.association_type)
        for edge in response.edges
    ]
    assert observed_edges == [("arn1", "arn2", "Produced")]

    observed_vertices = [
        (vertex.arn, vertex.lineage_source, vertex.lineage_entity)
        for vertex in response.vertices
    ]
    assert observed_vertices == [
        ("arn1", "Endpoint", "Artifact"),
        ("arn2", "Model", "Context"),
    ]
def models_v2(
    self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.DESCENDANTS
) -> List[Artifact]:
    """Use the lineage query to retrieve downstream model artifacts that use this endpoint.

    The lookup runs in two stages: first the model deployment actions reachable
    from this context are queried, then each of those actions is used as the
    start of a second query for model artifacts.

    Args:
        direction (LineageQueryDirectionEnum, optional): The query direction used
            for the first (model deployment) query. The second query always
            searches ``DESCENDANTS``.

    Returns:
        list of Artifacts: Artifacts representing a model. Empty when the first
            query yields a falsy result or no model vertices are found.
    """
    # Stage 1: find the model deployment actions.
    deployment_filter = LineageFilter(
        entities=[LineageEntityEnum.ACTION],
        sources=[LineageSourceEnum.MODEL_DEPLOYMENT],
    )
    deployment_result = LineageQuery(self.sagemaker_session).query(
        start_arns=[self.context_arn],
        query_filter=deployment_filter,
        direction=direction,
        include_edges=False,
    )
    if not deployment_result:
        return []

    # Stage 2: for every deployment action, collect the model artifact vertices.
    model_vertices = []
    for deployment_vertex in deployment_result.vertices:
        model_result = LineageQuery(self.sagemaker_session).query(
            start_arns=[deployment_vertex.arn],
            query_filter=LineageFilter(
                entities=[LineageEntityEnum.ARTIFACT],
                sources=[LineageSourceEnum.MODEL],
            ),
            direction=LineageQueryDirectionEnum.DESCENDANTS,
            include_edges=False,
        )
        model_vertices.extend(model_result.vertices)

    # Conversion happens only after all queries have completed; an empty
    # vertex list naturally produces an empty result.
    return [vertex.to_lineage_object() for vertex in model_vertices]
def dataset_artifacts(
    self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS
) -> List[Artifact]:
    """Get artifacts representing datasets from the model's lineage.

    Args:
        direction (LineageQueryDirectionEnum, optional): The query direction.

    Returns:
        list of Artifacts: Artifacts representing a dataset.
    """
    dataset_filter = LineageFilter(
        entities=[LineageEntityEnum.ARTIFACT],
        sources=[LineageSourceEnum.DATASET],
    )
    lineage = LineageQuery(self.sagemaker_session)
    result = lineage.query(
        start_arns=[self.artifact_arn],
        query_filter=dataset_filter,
        direction=direction,
        include_edges=False,
    )
    return [vertex.to_lineage_object() for vertex in result.vertices]
def endpoint_contexts(
    self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.DESCENDANTS
) -> List[Context]:
    """Get contexts representing endpoints from the dataset's lineage.

    Args:
        direction (LineageQueryDirectionEnum, optional): The query direction.

    Returns:
        list of Contexts: Contexts representing an endpoint.
    """
    endpoint_filter = LineageFilter(
        entities=[LineageEntityEnum.CONTEXT],
        sources=[LineageSourceEnum.ENDPOINT],
    )
    lineage = LineageQuery(self.sagemaker_session)
    result = lineage.query(
        start_arns=[self.artifact_arn],
        query_filter=endpoint_filter,
        direction=direction,
        include_edges=False,
    )
    return [vertex.to_lineage_object() for vertex in result.vertices]
def static_training_job_trial_component(
    sagemaker_session, static_model_artifact
) -> LineageTrialComponent:
    """Return the first training-job trial component upstream of the model artifact.

    Queries the ascendants of ``static_model_artifact`` for trial components
    whose source is a training job, logs how many were found and converts every
    vertex to its lineage object.

    Raises:
        Exception: If the query returns no matching trial component.
    """
    model_artifact_arn = static_model_artifact.artifact_arn
    query_result = LineageQuery(sagemaker_session).query(
        start_arns=[model_artifact_arn],
        query_filter=LineageFilter(
            entities=[LineageEntityEnum.TRIAL_COMPONENT],
            sources=[LineageSourceEnum.TRAINING_JOB],
        ),
        direction=LineageQueryDirectionEnum.ASCENDANTS,
        include_edges=False,
    )
    logging.info(
        f"Found {len(query_result.vertices)} trial components from model artifact {model_artifact_arn}"
    )

    trial_components = [vertex.to_lineage_object() for vertex in query_result.vertices]
    if trial_components:
        return trial_components[0]
    raise Exception(
        f"No training job found for static model artifact {model_artifact_arn}"
    )
def training_job_arns(
    self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS
) -> List[str]:
    """Get ARNs for all training jobs that appear in the endpoint's lineage.

    Each training-job trial component found by the lineage query is described
    through the SageMaker client, and the ``SourceArn`` of its ``Source`` entry
    is collected.

    Args:
        direction (LineageQueryDirectionEnum, optional): The query direction.

    Returns:
        list of str: Training job ARNs.
    """

    def _source_arn(trial_component_arn):
        # Resolve a trial component ARN to the ARN recorded as its source.
        name = _utils.get_resource_name_from_arn(trial_component_arn)
        description = self.sagemaker_session.sagemaker_client.describe_trial_component(
            TrialComponentName=name
        )
        return description["Source"]["SourceArn"]

    training_job_filter = LineageFilter(
        entities=[LineageEntityEnum.TRIAL_COMPONENT],
        sources=[LineageSourceEnum.TRAINING_JOB],
    )
    result = LineageQuery(self.sagemaker_session).query(
        start_arns=[self.context_arn],
        query_filter=training_job_filter,
        direction=direction,
        include_edges=False,
    )
    return [_source_arn(vertex.arn) for vertex in result.vertices]
def test_lineage_query_cross_account_same_artifact(sagemaker_session):
    """Check handling of one artifact seen from two accounts and linked by SAME_AS edges.

    The mocked payload holds two vertices that differ only in account id and two
    ``SAME_AS`` edges pointing in opposite directions between them. The response
    is expected to contain no edges and only the first vertex.
    """
    lineage_query = LineageQuery(sagemaker_session)

    first_account_arn = (
        "arn:aws:sagemaker:us-east-2:012345678901:artifact/e1f29799189751939405b0f2b5b9d2a0"
    )
    second_account_arn = (
        "arn:aws:sagemaker:us-east-2:012345678902:artifact/e1f29799189751939405b0f2b5b9d2a0"
    )

    def _endpoint_vertex(arn):
        return {"Arn": arn, "Type": "Endpoint", "LineageType": "Artifact"}

    def _same_as_edge(source_arn, destination_arn):
        return {
            "SourceArn": source_arn,
            "DestinationArn": destination_arn,
            "AssociationType": "SAME_AS",
        }

    sagemaker_session.sagemaker_client.query_lineage.return_value = {
        "Vertices": [
            _endpoint_vertex(first_account_arn),
            _endpoint_vertex(second_account_arn),
        ],
        "Edges": [
            _same_as_edge(first_account_arn, second_account_arn),
            _same_as_edge(second_account_arn, first_account_arn),
        ],
    }

    response = lineage_query.query(
        start_arns=["arn:aws:sagemaker:us-west-2:0123456789012:context/mycontext"]
    )

    assert len(response.edges) == 0
    assert len(response.vertices) == 1

    remaining_vertex = response.vertices[0]
    assert remaining_vertex.arn == first_account_arn
    assert remaining_vertex.lineage_source == "Endpoint"
    assert remaining_vertex.lineage_entity == "Artifact"
def static_model_deployment_action(sagemaker_session, static_endpoint_context):
    """Yield the first model-deployment action found upstream of the endpoint context.

    Every vertex returned by the ascendants query is converted to its lineage
    object before the first one is yielded; an empty result raises ``IndexError``.
    """
    result = LineageQuery(sagemaker_session).query(
        start_arns=[static_endpoint_context.context_arn],
        query_filter=LineageFilter(
            entities=[LineageEntityEnum.ACTION],
            sources=[LineageSourceEnum.MODEL_DEPLOYMENT],
        ),
        direction=LineageQueryDirectionEnum.ASCENDANTS,
        include_edges=False,
    )
    deployment_actions = [vertex.to_lineage_object() for vertex in result.vertices]
    yield deployment_actions[0]
def static_image_artifact(static_dataset_artifact, sagemaker_session):
    """Return the first image artifact found upstream of the dataset artifact.

    Every vertex returned by the ascendants query is converted to its lineage
    object before the first one is returned; an empty result raises ``IndexError``.
    """
    result = LineageQuery(sagemaker_session).query(
        start_arns=[static_dataset_artifact.artifact_arn],
        query_filter=LineageFilter(
            entities=[LineageEntityEnum.ARTIFACT],
            sources=[LineageSourceEnum.IMAGE],
        ),
        direction=LineageQueryDirectionEnum.ASCENDANTS,
        include_edges=False,
    )
    image_artifacts = [vertex.to_lineage_object() for vertex in result.vertices]
    return image_artifacts[0]
def static_approval_action(
    sagemaker_session, static_endpoint_context, static_pipeline_execution_arn
):
    """Yield the approval action found upstream of the endpoint context.

    The action name is taken as the segment after the first ``/`` in the ARN of
    the first vertex returned by the ascendants query, and the action is then
    loaded as a ``ModelPackageApprovalAction``.

    ``static_pipeline_execution_arn`` is not read in the body.
    # NOTE(review): presumably requested only to order fixture setup - confirm.
    """
    approval_filter = LineageFilter(
        entities=[LineageEntityEnum.ACTION],
        sources=[LineageSourceEnum.APPROVAL],
    )
    result = LineageQuery(sagemaker_session).query(
        start_arns=[static_endpoint_context.context_arn],
        query_filter=approval_filter,
        direction=LineageQueryDirectionEnum.ASCENDANTS,
        include_edges=False,
    )
    approval_arn = result.vertices[0].arn
    action_name = approval_arn.split("/")[1]
    yield action.ModelPackageApprovalAction.load(
        action_name=action_name,
        sagemaker_session=sagemaker_session,
    )
def static_transform_job_trial_component(
    static_processing_job_trial_component, sagemaker_session, static_endpoint_context
) -> LineageTrialComponent:
    """Yield the first transform-job trial component downstream of the processing job.

    Every vertex returned by the descendants query is converted to its lineage
    object before the first one is yielded; an empty result raises ``IndexError``.
    ``static_endpoint_context`` is not read in the body.
    """
    result = LineageQuery(sagemaker_session).query(
        start_arns=[static_processing_job_trial_component.trial_component_arn],
        query_filter=LineageFilter(
            entities=[LineageEntityEnum.TRIAL_COMPONENT],
            sources=[LineageSourceEnum.TRANSFORM_JOB],
        ),
        direction=LineageQueryDirectionEnum.DESCENDANTS,
        include_edges=False,
    )
    transform_components = [vertex.to_lineage_object() for vertex in result.vertices]
    yield transform_components[0]
def static_processing_job_trial_component(
    sagemaker_session, static_dataset_artifact
) -> LineageTrialComponent:
    """Return the first processing-job trial component upstream of the dataset artifact.

    Every vertex returned by the ascendants query is converted to its lineage
    object before the first one is returned; an empty result raises ``IndexError``.
    """
    result = LineageQuery(sagemaker_session).query(
        start_arns=[static_dataset_artifact.artifact_arn],
        query_filter=LineageFilter(
            entities=[LineageEntityEnum.TRIAL_COMPONENT],
            sources=[LineageSourceEnum.PROCESSING_JOB],
        ),
        direction=LineageQueryDirectionEnum.ASCENDANTS,
        include_edges=False,
    )
    processing_components = [vertex.to_lineage_object() for vertex in result.vertices]
    return processing_components[0]
def actions(self, direction: LineageQueryDirectionEnum) -> List[Action]:
    """Use the lineage query to retrieve actions that use this context.

    Args:
        direction (LineageQueryDirectionEnum): The query direction.

    Returns:
        list of Actions: Actions.
    """
    action_filter = LineageFilter(entities=[LineageEntityEnum.ACTION])
    lineage = LineageQuery(self.sagemaker_session)
    found = lineage.query(
        start_arns=[self.context_arn],
        query_filter=action_filter,
        direction=direction,
        include_edges=False,
    )
    return [node.to_lineage_object() for node in found.vertices]
def datasets(self, direction: LineageQueryDirectionEnum) -> List[Artifact]:
    """Use the lineage query to retrieve datasets that use this image artifact.

    Args:
        direction (LineageQueryDirectionEnum): The query direction.

    Returns:
        list of Artifacts: Artifacts representing a dataset.
    """
    dataset_filter = LineageFilter(
        entities=[LineageEntityEnum.ARTIFACT],
        sources=[LineageSourceEnum.DATASET],
    )
    lineage = LineageQuery(self.sagemaker_session)
    found = lineage.query(
        start_arns=[self.artifact_arn],
        query_filter=dataset_filter,
        direction=direction,
        include_edges=False,
    )
    return [node.to_lineage_object() for node in found.vertices]
def artifacts(
    self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.BOTH
) -> List[Artifact]:
    """Use a lineage query to retrieve all artifacts that use this action.

    Args:
        direction (LineageQueryDirectionEnum, optional): The query direction.

    Returns:
        list of Artifacts: Artifacts.
    """
    artifact_filter = LineageFilter(entities=[LineageEntityEnum.ARTIFACT])
    lineage = LineageQuery(self.sagemaker_session)
    found = lineage.query(
        start_arns=[self.action_arn],
        query_filter=artifact_filter,
        direction=direction,
        include_edges=False,
    )
    return [node.to_lineage_object() for node in found.vertices]
def endpoints(
    self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.DESCENDANTS
) -> List:
    """Use a lineage query to retrieve downstream endpoint contexts that use this action.

    Args:
        direction (LineageQueryDirectionEnum, optional): The query direction.

    Returns:
        list of Contexts: Contexts representing an endpoint.
    """
    endpoint_filter = LineageFilter(
        entities=[LineageEntityEnum.CONTEXT],
        sources=[LineageSourceEnum.ENDPOINT],
    )
    lineage = LineageQuery(self.sagemaker_session)
    found = lineage.query(
        start_arns=[self.action_arn],
        query_filter=endpoint_filter,
        direction=direction,
        include_edges=False,
    )
    return [node.to_lineage_object() for node in found.vertices]
def models(
    self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.DESCENDANTS
) -> List[Artifact]:
    """Use the lineage query to retrieve models that use this trial component.

    Args:
        direction (LineageQueryDirectionEnum, optional): The query direction.

    Returns:
        list of Artifacts: Artifacts representing a model.
    """
    model_filter = LineageFilter(
        entities=[LineageEntityEnum.ARTIFACT],
        sources=[LineageSourceEnum.MODEL],
    )
    lineage = LineageQuery(self.sagemaker_session)
    found = lineage.query(
        start_arns=[self.trial_component_arn],
        query_filter=model_filter,
        direction=direction,
        include_edges=False,
    )
    return [node.to_lineage_object() for node in found.vertices]
def trial_components(
    self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS
) -> List[LineageTrialComponent]:
    """Use the lineage query to retrieve trial components that use this endpoint.

    Args:
        direction (LineageQueryDirectionEnum, optional): The query direction.

    Returns:
        list of LineageTrialComponent: Lineage trial component.
    """
    trial_component_filter = LineageFilter(entities=[LineageEntityEnum.TRIAL_COMPONENT])
    lineage = LineageQuery(self.sagemaker_session)
    found = lineage.query(
        start_arns=[self.context_arn],
        query_filter=trial_component_filter,
        direction=direction,
        include_edges=False,
    )
    return [node.to_lineage_object() for node in found.vertices]
def _trials(
    self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.BOTH
) -> List:
    """Use the lineage query to retrieve all trials that use this artifact.

    The ARNs of the trial components found by the query are handed to
    ``_get_trial_from_trial_component``, whose result is returned unchanged.

    Args:
        direction (LineageQueryDirectionEnum, optional): The query direction.

    Returns:
        [Trial]: A list of SageMaker `Trial` objects.
    """
    trial_component_filter = LineageFilter(entities=[LineageEntityEnum.TRIAL_COMPONENT])
    lineage = LineageQuery(self.sagemaker_session)
    found = lineage.query(
        start_arns=[self.artifact_arn],
        query_filter=trial_component_filter,
        direction=direction,
        include_edges=False,
    )
    trial_component_arns = [node.arn for node in found.vertices]
    return self._get_trial_from_trial_component(trial_component_arns)