def test_lineage_query_duplication(sagemaker_session):
    """Check that repeated vertices and edges in a query response are collapsed.

    The stubbed ``query_lineage`` response lists vertex ``arn1`` twice and the
    ``arn1 -> arn2`` edge twice; the parsed result is expected to expose two
    vertices and a single edge, in response order.
    """
    lineage_query = LineageQuery(sagemaker_session)

    stubbed_vertices = [
        {"Arn": "arn1", "Type": "Endpoint", "LineageType": "Artifact"},
        {"Arn": "arn1", "Type": "Endpoint", "LineageType": "Artifact"},
        {"Arn": "arn2", "Type": "Model", "LineageType": "Context"},
    ]
    stubbed_edges = [
        {"SourceArn": "arn1", "DestinationArn": "arn2", "AssociationType": "Produced"},
        {"SourceArn": "arn1", "DestinationArn": "arn2", "AssociationType": "Produced"},
    ]
    sagemaker_session.sagemaker_client.query_lineage.return_value = {
        "Vertices": stubbed_vertices,
        "Edges": stubbed_edges,
    }

    response = lineage_query.query(
        start_arns=["arn:aws:sagemaker:us-west-2:0123456789012:context/mycontext"]
    )

    # Exactly one edge must survive.
    assert len(response.edges) == 1
    surviving_edge = response.edges[0]
    assert surviving_edge.source_arn == "arn1"
    assert surviving_edge.destination_arn == "arn2"
    assert surviving_edge.association_type == "Produced"

    # Two distinct vertices must survive, in the order they first appeared.
    assert len(response.vertices) == 2
    first_vertex, second_vertex = response.vertices[0], response.vertices[1]
    assert first_vertex.arn == "arn1"
    assert first_vertex.lineage_source == "Endpoint"
    assert first_vertex.lineage_entity == "Artifact"
    assert second_vertex.arn == "arn2"
    assert second_vertex.lineage_source == "Model"
    assert second_vertex.lineage_entity == "Context"
    def models_v2(
        self,
        direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.
        DESCENDANTS
    ) -> List[Artifact]:
        """Use two chained lineage queries to retrieve model artifacts for this context.

        The first query starts at ``self.context_arn`` and collects model-deployment
        action vertices in ``direction``. A second query is then issued from each of
        those vertices, always towards descendants, collecting model artifact
        vertices, which are finally converted to lineage objects.

        Args:
            direction (LineageQueryDirectionEnum, optional): Direction of the first
                (model deployment) query. Defaults to ``DESCENDANTS``.

        Returns:
            list of Artifacts: Lineage objects built from the model vertices; an
                empty list when the first query result is falsy or no model
                vertices are found.
        """
        # Stage 1: find the model_deployment action vertices.
        deployment_filter = LineageFilter(
            entities=[LineageEntityEnum.ACTION],
            sources=[LineageSourceEnum.MODEL_DEPLOYMENT])
        deployment_result = LineageQuery(self.sagemaker_session).query(
            start_arns=[self.context_arn],
            query_filter=deployment_filter,
            direction=direction,
            include_edges=False,
        )
        if not deployment_result:
            return []

        # Stage 2: from every deployment vertex, look for model artifacts.
        # This stage always walks towards descendants, regardless of `direction`.
        model_vertices = []
        for deployment_vertex in deployment_result.vertices:
            model_result = LineageQuery(self.sagemaker_session).query(
                start_arns=[deployment_vertex.arn],
                query_filter=LineageFilter(
                    entities=[LineageEntityEnum.ARTIFACT],
                    sources=[LineageSourceEnum.MODEL]),
                direction=LineageQueryDirectionEnum.DESCENDANTS,
                include_edges=False,
            )
            model_vertices.extend(model_result.vertices)

        # An empty vertex list naturally yields an empty result list.
        return [vertex.to_lineage_object() for vertex in model_vertices]
    def dataset_artifacts(
        self,
        direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.
        ASCENDANTS
    ) -> List[Artifact]:
        """Collect dataset artifacts found by a lineage query starting at this artifact.

        Args:
            direction (LineageQueryDirectionEnum, optional): Direction in which the
                lineage graph is walked. Defaults to ``ASCENDANTS``.

        Returns:
            list of Artifacts: One lineage object per dataset vertex in the query result.
        """
        dataset_filter = LineageFilter(
            entities=[LineageEntityEnum.ARTIFACT],
            sources=[LineageSourceEnum.DATASET],
        )
        lineage_query = LineageQuery(self.sagemaker_session)
        found = lineage_query.query(
            start_arns=[self.artifact_arn],
            query_filter=dataset_filter,
            direction=direction,
            include_edges=False,
        )
        return [dataset_vertex.to_lineage_object() for dataset_vertex in found.vertices]
    def endpoint_contexts(
        self,
        direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.
        DESCENDANTS
    ) -> List[Context]:
        """Collect endpoint contexts found by a lineage query starting at this artifact.

        Args:
            direction (LineageQueryDirectionEnum, optional): Direction in which the
                lineage graph is walked. Defaults to ``DESCENDANTS``.

        Returns:
            list of Contexts: One lineage object per endpoint vertex in the query result.
        """
        endpoint_filter = LineageFilter(
            entities=[LineageEntityEnum.CONTEXT],
            sources=[LineageSourceEnum.ENDPOINT],
        )
        endpoint_vertices = LineageQuery(self.sagemaker_session).query(
            start_arns=[self.artifact_arn],
            query_filter=endpoint_filter,
            direction=direction,
            include_edges=False,
        ).vertices
        return [endpoint_vertex.to_lineage_object() for endpoint_vertex in endpoint_vertices]
Beispiel #5
0
def static_training_job_trial_component(
        sagemaker_session, static_model_artifact) -> LineageTrialComponent:
    """Return the first training-job trial component found from the model artifact.

    A lineage query is run from ``static_model_artifact.artifact_arn`` in the
    ``ASCENDANTS`` direction, filtered to training-job trial components.

    Args:
        sagemaker_session: Session handed to ``LineageQuery``.
        static_model_artifact: Object whose ``artifact_arn`` is the query start point.

    Returns:
        LineageTrialComponent: Lineage object built from the first vertex of the result.

    Raises:
        Exception: If the query returns no vertices.
    """
    training_job_filter = LineageFilter(
        entities=[LineageEntityEnum.TRIAL_COMPONENT],
        sources=[LineageSourceEnum.TRAINING_JOB],
    )

    model_artifact_arn = static_model_artifact.artifact_arn
    query_result = LineageQuery(sagemaker_session).query(
        start_arns=[model_artifact_arn],
        query_filter=training_job_filter,
        direction=LineageQueryDirectionEnum.ASCENDANTS,
        include_edges=False,
    )
    logging.info(
        f"Found {len(query_result.vertices)} trial components from model artifact {model_artifact_arn}"
    )

    # Every vertex is converted, even though only the first one is returned.
    training_jobs = [vertex.to_lineage_object() for vertex in query_result.vertices]
    if training_jobs:
        return training_jobs[0]

    raise Exception(
        f"No training job found for static model artifact {model_artifact_arn}"
    )
    def training_job_arns(
        self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.ASCENDANTS
    ) -> List[str]:
        """Resolve the source ARN of each training-job trial component in this context's lineage.

        Every trial component vertex found by the lineage query is described through
        the SageMaker client, and its ``Source.SourceArn`` value is collected.

        Args:
            direction (LineageQueryDirectionEnum, optional): Direction in which the
                lineage graph is walked. Defaults to ``ASCENDANTS``.

        Returns:
            list of str: Training job ARNs, one per vertex, in query-result order.
        """

        def _source_arn_of(trial_component_vertex):
            # Look the trial component up by the name embedded in its ARN.
            name = _utils.get_resource_name_from_arn(trial_component_vertex.arn)
            description = self.sagemaker_session.sagemaker_client.describe_trial_component(
                TrialComponentName=name
            )
            return description["Source"]["SourceArn"]

        training_job_filter = LineageFilter(
            entities=[LineageEntityEnum.TRIAL_COMPONENT], sources=[LineageSourceEnum.TRAINING_JOB]
        )
        lineage_query = LineageQuery(self.sagemaker_session)
        found = lineage_query.query(
            start_arns=[self.context_arn],
            query_filter=training_job_filter,
            direction=direction,
            include_edges=False,
        )
        return [_source_arn_of(vertex) for vertex in found.vertices]
def test_lineage_query_cross_account_same_artifact(sagemaker_session):
    """Check how a pair of cross-account ``SAME_AS`` artifacts is reduced.

    The stubbed response holds the same artifact id under two account ids, linked
    by ``SAME_AS`` edges in both directions; the query result is expected to keep
    only the first vertex and no edges.
    """
    lineage_query = LineageQuery(sagemaker_session)

    first_account_arn = (
        "arn:aws:sagemaker:us-east-2:012345678901:artifact/e1f29799189751939405b0f2b5b9d2a0"
    )
    second_account_arn = (
        "arn:aws:sagemaker:us-east-2:012345678902:artifact/e1f29799189751939405b0f2b5b9d2a0"
    )

    stubbed_vertices = [
        {"Arn": arn, "Type": "Endpoint", "LineageType": "Artifact"}
        for arn in (first_account_arn, second_account_arn)
    ]
    stubbed_edges = [
        {"SourceArn": source, "DestinationArn": destination, "AssociationType": "SAME_AS"}
        for source, destination in (
            (first_account_arn, second_account_arn),
            (second_account_arn, first_account_arn),
        )
    ]
    sagemaker_session.sagemaker_client.query_lineage.return_value = {
        "Vertices": stubbed_vertices,
        "Edges": stubbed_edges,
    }

    response = lineage_query.query(
        start_arns=["arn:aws:sagemaker:us-west-2:0123456789012:context/mycontext"]
    )

    assert len(response.edges) == 0
    assert len(response.vertices) == 1

    kept_vertex = response.vertices[0]
    assert kept_vertex.arn == first_account_arn
    assert kept_vertex.lineage_source == "Endpoint"
    assert kept_vertex.lineage_entity == "Artifact"
def static_model_deployment_action(sagemaker_session, static_endpoint_context):
    """Yield the first model-deployment action found from the endpoint context.

    A lineage query is run from ``static_endpoint_context.context_arn`` in the
    ``ASCENDANTS`` direction, filtered to model-deployment actions.

    Args:
        sagemaker_session: Session handed to ``LineageQuery``.
        static_endpoint_context: Object whose ``context_arn`` is the query start point.

    Yields:
        The lineage object built from the first vertex of the query result.
    """
    deployment_filter = LineageFilter(
        entities=[LineageEntityEnum.ACTION],
        sources=[LineageSourceEnum.MODEL_DEPLOYMENT],
    )
    deployment_vertices = LineageQuery(sagemaker_session).query(
        start_arns=[static_endpoint_context.context_arn],
        query_filter=deployment_filter,
        direction=LineageQueryDirectionEnum.ASCENDANTS,
        include_edges=False,
    ).vertices
    # Every vertex is converted, even though only the first one is yielded.
    deployment_actions = [vertex.to_lineage_object() for vertex in deployment_vertices]
    yield deployment_actions[0]
Beispiel #9
0
def static_image_artifact(static_dataset_artifact, sagemaker_session):
    """Return the first image artifact found from the dataset artifact.

    A lineage query is run from ``static_dataset_artifact.artifact_arn`` in the
    ``ASCENDANTS`` direction, filtered to image artifacts.

    Args:
        static_dataset_artifact: Object whose ``artifact_arn`` is the query start point.
        sagemaker_session: Session handed to ``LineageQuery``.

    Returns:
        The lineage object built from the first vertex of the query result.
    """
    image_filter = LineageFilter(
        entities=[LineageEntityEnum.ARTIFACT],
        sources=[LineageSourceEnum.IMAGE],
    )
    lineage_query = LineageQuery(sagemaker_session)
    found = lineage_query.query(
        start_arns=[static_dataset_artifact.artifact_arn],
        query_filter=image_filter,
        direction=LineageQueryDirectionEnum.ASCENDANTS,
        include_edges=False,
    )
    # Every vertex is converted, even though only the first one is returned.
    image_artifacts = [vertex.to_lineage_object() for vertex in found.vertices]
    return image_artifacts[0]
Beispiel #10
0
def static_approval_action(sagemaker_session, static_endpoint_context,
                           static_pipeline_execution_arn):
    """Yield the approval action loaded for the first approval vertex found.

    A lineage query is run from ``static_endpoint_context.context_arn`` in the
    ``ASCENDANTS`` direction, filtered to approval actions. The action name is
    taken from the first vertex's ARN and loaded as a
    ``ModelPackageApprovalAction``.

    Args:
        sagemaker_session: Session used for the query and for loading the action.
        static_endpoint_context: Object whose ``context_arn`` is the query start point.
        static_pipeline_execution_arn: Not read in the body.

    Yields:
        The object returned by ``action.ModelPackageApprovalAction.load``.
    """
    approval_filter = LineageFilter(
        entities=[LineageEntityEnum.ACTION],
        sources=[LineageSourceEnum.APPROVAL],
    )
    approval_vertices = LineageQuery(sagemaker_session).query(
        start_arns=[static_endpoint_context.context_arn],
        query_filter=approval_filter,
        direction=LineageQueryDirectionEnum.ASCENDANTS,
        include_edges=False,
    ).vertices

    # The action name is the second "/"-separated segment of the first vertex ARN.
    first_vertex_arn = approval_vertices[0].arn
    arn_segments = first_vertex_arn.split("/")
    action_name = arn_segments[1]

    yield action.ModelPackageApprovalAction.load(
        action_name=action_name,
        sagemaker_session=sagemaker_session,
    )
Beispiel #11
0
def static_transform_job_trial_component(
        static_processing_job_trial_component, sagemaker_session,
        static_endpoint_context) -> LineageTrialComponent:
    """Yield the first transform-job trial component found from the processing job.

    A lineage query is run from
    ``static_processing_job_trial_component.trial_component_arn`` in the
    ``DESCENDANTS`` direction, filtered to transform-job trial components.

    Args:
        static_processing_job_trial_component: Object whose ``trial_component_arn``
            is the query start point.
        sagemaker_session: Session handed to ``LineageQuery``.
        static_endpoint_context: Not read in the body.

    Yields:
        The lineage object built from the first vertex of the query result.
    """
    transform_filter = LineageFilter(
        entities=[LineageEntityEnum.TRIAL_COMPONENT],
        sources=[LineageSourceEnum.TRANSFORM_JOB],
    )
    lineage_query = LineageQuery(sagemaker_session)
    found = lineage_query.query(
        start_arns=[static_processing_job_trial_component.trial_component_arn],
        query_filter=transform_filter,
        direction=LineageQueryDirectionEnum.DESCENDANTS,
        include_edges=False,
    )
    # Every vertex is converted, even though only the first one is yielded.
    transform_jobs = [vertex.to_lineage_object() for vertex in found.vertices]
    yield transform_jobs[0]
Beispiel #12
0
def static_processing_job_trial_component(
        sagemaker_session, static_dataset_artifact) -> LineageTrialComponent:
    """Return the first processing-job trial component found from the dataset artifact.

    A lineage query is run from ``static_dataset_artifact.artifact_arn`` in the
    ``ASCENDANTS`` direction, filtered to processing-job trial components.

    Args:
        sagemaker_session: Session handed to ``LineageQuery``.
        static_dataset_artifact: Object whose ``artifact_arn`` is the query start point.

    Returns:
        LineageTrialComponent: Lineage object built from the first vertex of the result.
    """
    processing_filter = LineageFilter(
        entities=[LineageEntityEnum.TRIAL_COMPONENT],
        sources=[LineageSourceEnum.PROCESSING_JOB],
    )
    processing_vertices = LineageQuery(sagemaker_session).query(
        start_arns=[static_dataset_artifact.artifact_arn],
        query_filter=processing_filter,
        direction=LineageQueryDirectionEnum.ASCENDANTS,
        include_edges=False,
    ).vertices

    # Every vertex is converted, even though only the first one is returned.
    processing_jobs = [vertex.to_lineage_object() for vertex in processing_vertices]
    return processing_jobs[0]
Beispiel #13
0
    def actions(self, direction: LineageQueryDirectionEnum) -> List[Action]:
        """Run a lineage query from this context and return the action vertices found.

        Args:
            direction (LineageQueryDirectionEnum): Direction in which the lineage
                graph is walked.

        Returns:
            list of Actions: One lineage object per action vertex in the query result.
        """
        action_filter = LineageFilter(entities=[LineageEntityEnum.ACTION])
        lineage_query = LineageQuery(self.sagemaker_session)
        found = lineage_query.query(
            start_arns=[self.context_arn],
            query_filter=action_filter,
            direction=direction,
            include_edges=False,
        )
        action_vertices = found.vertices
        return [action_vertex.to_lineage_object() for action_vertex in action_vertices]
    def datasets(self, direction: LineageQueryDirectionEnum) -> List[Artifact]:
        """Run a lineage query from this artifact and return the dataset artifacts found.

        Args:
            direction (LineageQueryDirectionEnum): Direction in which the lineage
                graph is walked.

        Returns:
            list of Artifacts: One lineage object per dataset vertex in the query result.
        """
        dataset_filter = LineageFilter(
            entities=[LineageEntityEnum.ARTIFACT],
            sources=[LineageSourceEnum.DATASET],
        )
        dataset_vertices = LineageQuery(self.sagemaker_session).query(
            start_arns=[self.artifact_arn],
            query_filter=dataset_filter,
            direction=direction,
            include_edges=False,
        ).vertices
        return [dataset_vertex.to_lineage_object() for dataset_vertex in dataset_vertices]
Beispiel #15
0
    def artifacts(
        self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.BOTH
    ) -> List[Artifact]:
        """Run a lineage query from this action and return the artifacts found.

        Args:
            direction (LineageQueryDirectionEnum, optional): Direction in which the
                lineage graph is walked. Defaults to ``BOTH``.

        Returns:
            list of Artifacts: One lineage object per artifact vertex in the query result.
        """
        artifact_filter = LineageFilter(entities=[LineageEntityEnum.ARTIFACT])
        lineage_query = LineageQuery(self.sagemaker_session)
        found = lineage_query.query(
            start_arns=[self.action_arn],
            query_filter=artifact_filter,
            direction=direction,
            include_edges=False,
        )
        return [artifact_vertex.to_lineage_object() for artifact_vertex in found.vertices]
Beispiel #16
0
    def endpoints(
        self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.DESCENDANTS
    ) -> List:
        """Run a lineage query from this action and return the endpoint contexts found.

        Args:
            direction (LineageQueryDirectionEnum, optional): Direction in which the
                lineage graph is walked. Defaults to ``DESCENDANTS``.

        Returns:
            list of Contexts: One lineage object per endpoint vertex in the query result.
        """
        endpoint_filter = LineageFilter(
            entities=[LineageEntityEnum.CONTEXT],
            sources=[LineageSourceEnum.ENDPOINT],
        )
        endpoint_vertices = LineageQuery(self.sagemaker_session).query(
            start_arns=[self.action_arn],
            query_filter=endpoint_filter,
            direction=direction,
            include_edges=False,
        ).vertices
        return [endpoint_vertex.to_lineage_object() for endpoint_vertex in endpoint_vertices]
    def models(
        self, direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.DESCENDANTS
    ) -> List[Artifact]:
        """Run a lineage query from this trial component and return the model artifacts found.

        Args:
            direction (LineageQueryDirectionEnum, optional): Direction in which the
                lineage graph is walked. Defaults to ``DESCENDANTS``.

        Returns:
            list of Artifacts: One lineage object per model vertex in the query result.
        """
        model_filter = LineageFilter(
            entities=[LineageEntityEnum.ARTIFACT],
            sources=[LineageSourceEnum.MODEL],
        )
        lineage_query = LineageQuery(self.sagemaker_session)
        found = lineage_query.query(
            start_arns=[self.trial_component_arn],
            query_filter=model_filter,
            direction=direction,
            include_edges=False,
        )
        model_vertices = found.vertices
        return [model_vertex.to_lineage_object() for model_vertex in model_vertices]
Beispiel #18
0
    def trial_components(
        self,
        direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.
        ASCENDANTS
    ) -> List[LineageTrialComponent]:
        """Run a lineage query from this context and return the trial components found.

        Args:
            direction (LineageQueryDirectionEnum, optional): Direction in which the
                lineage graph is walked. Defaults to ``ASCENDANTS``.

        Returns:
            list of LineageTrialComponent: One lineage object per trial component
                vertex in the query result.
        """
        trial_component_filter = LineageFilter(
            entities=[LineageEntityEnum.TRIAL_COMPONENT],
        )
        trial_component_vertices = LineageQuery(self.sagemaker_session).query(
            start_arns=[self.context_arn],
            query_filter=trial_component_filter,
            direction=direction,
            include_edges=False,
        ).vertices
        return [
            trial_component_vertex.to_lineage_object()
            for trial_component_vertex in trial_component_vertices
        ]
    def _trials(
        self,
        direction: LineageQueryDirectionEnum = LineageQueryDirectionEnum.BOTH
    ) -> List:
        """Look up the trials behind the trial components in this artifact's lineage.

        A lineage query collects trial component vertices starting at
        ``self.artifact_arn``; their ARNs are then passed to
        ``self._get_trial_from_trial_component``.

        Args:
            direction (LineageQueryDirectionEnum, optional): Direction in which the
                lineage graph is walked. Defaults to ``BOTH``.

        Returns:
            [Trial]: Whatever ``self._get_trial_from_trial_component`` returns for
                the collected ARNs (presumably a list of SageMaker `Trial`
                objects; confirm against that helper).
        """
        trial_component_filter = LineageFilter(
            entities=[LineageEntityEnum.TRIAL_COMPONENT],
        )
        lineage_query = LineageQuery(self.sagemaker_session)
        found = lineage_query.query(
            start_arns=[self.artifact_arn],
            query_filter=trial_component_filter,
            direction=direction,
            include_edges=False,
        )
        trial_component_arns = [vertex.arn for vertex in found.vertices]
        return self._get_trial_from_trial_component(trial_component_arns)