def test_dataset_artifacts(sagemaker_session):
    """dataset_artifacts() returns the DataSet artifacts found via query_lineage.

    ``query_lineage`` is mocked to return a single DataSet artifact vertex and
    ``describe_artifact`` is mocked to return that artifact's name and ARN.
    The test checks the exact ``query_lineage`` request and that exactly one
    artifact with the expected ARN and name is returned.
    """
    trial_component_arn = (
        "arn:aws:sagemaker:us-west-2:123456789012:trial_component/lineage-unit-3b05f017-0d87-4c37"
    )
    artifact_dataset_arn = "arn:aws:sagemaker:us-west-2:123456789012:artifact/datasets"
    artifact_dataset_name = "myDataset"

    obj = lineage_trial_component.LineageTrialComponent(
        sagemaker_session,
        trial_component_name="foo",
        trial_component_arn=trial_component_arn,
    )

    sagemaker_session.sagemaker_client.query_lineage.return_value = {
        "Vertices": [
            {
                "Arn": artifact_dataset_arn,
                "Type": "DataSet",
                "LineageType": "Artifact",
            },
        ],
        "Edges": [
            {
                "SourceArn": "arn1",
                "DestinationArn": "arn2",
                "AssociationType": "Produced",
            }
        ],
    }
    sagemaker_session.sagemaker_client.describe_artifact.return_value = {
        "ArtifactName": artifact_dataset_name,
        "ArtifactArn": artifact_dataset_arn,
    }

    dataset_list = obj.dataset_artifacts()

    expected_calls = [
        unittest.mock.call(
            Direction="Ascendants",
            Filters={"Types": ["DataSet"], "LineageTypes": ["Artifact"]},
            IncludeEdges=False,
            MaxDepth=10,
            StartArns=[trial_component_arn],
        ),
    ]
    assert expected_calls == sagemaker_session.sagemaker_client.query_lineage.mock_calls

    expected_dataset_list = [
        artifact.DatasetArtifact(
            artifact_name=artifact_dataset_name,
            artifact_arn=artifact_dataset_arn,
        )
    ]
    # Compare every returned artifact, not just index 0, so that an extra or
    # missing artifact fails here with a clear assertion instead of passing
    # silently or raising IndexError.
    assert [(a.artifact_arn, a.artifact_name) for a in dataset_list] == [
        (a.artifact_arn, a.artifact_name) for a in expected_dataset_list
    ]


def test_no_pipeline_execution_arn(sagemaker_session):
    """pipeline_execution_arn() returns None when no pipeline-execution tag exists.

    The trial component's source is a training job. ``list_tags`` is mocked to
    return only an unrelated ``abcd=efg`` tag. The test expects the result to be
    ``None`` and ``list_tags`` to be called once with the training job ARN.
    """
    trial_component_arn = (
        "arn:aws:sagemaker:us-west-2:123456789012:trial_component/lineage-unit-3b05f017-0d87-4c37"
    )
    training_job_arn = (
        "arn:aws:sagemaker:us-west-2:123456789012:training-job/pipelines-bs6gaeln463r-abalonetrain"
    )
    trial_component = lineage_trial_component.LineageTrialComponent(
        sagemaker_session,
        trial_component_name="foo",
        trial_component_arn=trial_component_arn,
        source={
            "SourceArn": training_job_arn,
            "SourceType": "SageMakerTrainingJob",
        },
    )

    job_trial_component_name = "pipelines-bs6gaeln463r-AbaloneTrain-A0QiDGuY6z-aws-training-job"
    sagemaker_session.sagemaker_client.describe_trial_component.return_value = {
        "TrialComponentName": job_trial_component_name,
        "TrialComponentArn": trial_component_arn,
        "DisplayName": job_trial_component_name,
        "Source": {
            "SourceArn": training_job_arn,
            "SourceType": "SageMakerTrainingJob",
        },
    }

    # The only tag is not a pipeline-execution tag, so None is expected.
    sagemaker_session.sagemaker_client.list_tags.return_value = {
        "Tags": [
            {"Key": "abcd", "Value": "efg"},
        ],
    }
    expected_calls = [
        unittest.mock.call(ResourceArn=training_job_arn),
    ]

    pipeline_execution_arn_result = trial_component.pipeline_execution_arn()

    # PEP 8: compare against the None singleton by identity, not equality.
    assert pipeline_execution_arn_result is None
    assert expected_calls == sagemaker_session.sagemaker_client.list_tags.mock_calls


# Example No. 3
def test_pipeline_execution_arn(sagemaker_session):
    """pipeline_execution_arn() returns the value of the pipeline-execution tag.

    ``list_tags`` is mocked to return a ``sagemaker:pipeline-execution-arn``
    tag with value ``"tag1"``. The test expects that value back and exactly one
    ``list_tags`` call, made with the trial component's own ARN.

    NOTE(review): this test expects ``list_tags`` to be called with the trial
    component ARN, whereas ``test_no_pipeline_execution_arn`` expects it to be
    called with the source training-job ARN. The two appear to target
    different implementations; confirm which one matches the code under test.
    """
    tc_arn = (
        "arn:aws:sagemaker:us-west-2:123456789012:trial_component/lineage-unit-3b05f017-0d87-4c37"
    )
    trial_component = lineage_trial_component.LineageTrialComponent(
        sagemaker_session,
        trial_component_name="foo",
        trial_component_arn=tc_arn,
    )

    client = sagemaker_session.sagemaker_client
    pipeline_tag = {"Key": "sagemaker:pipeline-execution-arn", "Value": "tag1"}
    client.list_tags.return_value = {"Tags": [pipeline_tag]}

    result = trial_component.pipeline_execution_arn()

    assert result == "tag1"
    assert [unittest.mock.call(ResourceArn=tc_arn)] == client.list_tags.mock_calls