Exemplo n.º 1
0
def test_downstream_trials(sagemaker_session):
    """Check that ``downstream_trials`` returns the trial names from the mocked search.

    The mocked ``list_associations`` yields a single page (``NextToken`` is
    ``None``) of ten association summaries. The mocked ``search`` returns one
    trial component whose parent trial is ``test-trial-name``. The test asserts
    the returned trial names. It also asserts that ``list_associations``
    recorded exactly one call, with ``SourceArn`` set to the artifact ARN.
    """
    client = sagemaker_session.sagemaker_client

    def make_summary(idx):
        # Every identifying field is tagged with the index; the timestamp is taken per summary.
        tag = str(idx)
        return {
            "SourceArn": "A" + tag,
            "SourceName": "X" + tag,
            "DestinationArn": "B" + tag,
            "DestinationName": "Y" + tag,
            "SourceType": "C" + tag,
            "DestinationType": "D" + tag,
            "AssociationType": "E" + tag,
            "CreationTime": datetime.datetime.now(),
            "CreatedBy": {},
        }

    single_page = {
        "AssociationSummaries": [make_summary(idx) for idx in range(10)],
        "NextToken": None,
    }
    client.list_associations.side_effect = [single_page]

    trial_component = {
        "TrialComponentName": "tc-1",
        "TrialComponentArn": "arn::tc-1",
        "DisplayName": "TC1",
        "Parents": [{"TrialName": "test-trial-name"}],
    }
    client.search.return_value = {"Results": [{"TrialComponent": trial_component}]}

    obj = artifact.Artifact(
        sagemaker_session=sagemaker_session,
        artifact_arn="test-arn",
        artifact_name="foo",
        properties={"k1": "v1", "k2": "v2"},
        properties_to_remove=["r1"],
    )

    trials = obj.downstream_trials(sagemaker_session=sagemaker_session)
    assert ["test-trial-name"] == trials

    assert [unittest.mock.call(SourceArn="test-arn")] == client.list_associations.mock_calls
Exemplo n.º 2
0
def test_create_delete_with_association(sagemaker_session):
    """Check that ``delete(disassociate=True)`` removes associations, then the artifact.

    The mocked ``list_associations`` returns two responses in order:
    - one association whose source is the artifact (destination ``B0``);
    - one association whose destination is the artifact (source ``A1``).

    The test asserts that each association is passed to ``delete_association``
    in that order. It also asserts that ``delete_artifact`` is invoked once for
    the artifact itself, mirroring the expectation in ``test_delete``.
    """
    obj = artifact.Artifact(sagemaker_session, artifact_arn="foo")

    sagemaker_session.sagemaker_client.list_associations.side_effect = [
        {
            "AssociationSummaries": [
                {
                    "SourceArn": obj.artifact_arn,
                    "SourceName": "X" + str(i),
                    "DestinationArn": "B" + str(i),
                    "DestinationName": "Y" + str(i),
                    "SourceType": "C" + str(i),
                    "DestinationType": "D" + str(i),
                    "AssociationType": "E" + str(i),
                    "CreationTime": None,
                    "CreatedBy": {},
                }
                for i in range(1)
            ],
        },
        {
            "AssociationSummaries": [
                {
                    "SourceArn": "A" + str(i),
                    "SourceName": "X" + str(i),
                    "DestinationArn": obj.artifact_arn,
                    "DestinationName": "Y" + str(i),
                    "SourceType": "C" + str(i),
                    "DestinationType": "D" + str(i),
                    "AssociationType": "E" + str(i),
                    "CreationTime": None,
                    "CreatedBy": {},
                }
                for i in range(1, 2)
            ]
        },
    ]
    sagemaker_session.sagemaker_client.delete_association.return_value = {}
    sagemaker_session.sagemaker_client.delete_artifact.return_value = {}

    obj.delete(disassociate=True)

    delete_with_association_expected_calls = [
        unittest.mock.call(SourceArn=obj.artifact_arn, DestinationArn="B0"),
        unittest.mock.call(SourceArn="A1", DestinationArn=obj.artifact_arn),
    ]
    assert (
        delete_with_association_expected_calls
        == sagemaker_session.sagemaker_client.delete_association.mock_calls
    )
    # Previously only the disassociation was verified. Without this check the
    # test passed even if the artifact itself was never deleted.
    sagemaker_session.sagemaker_client.delete_artifact.assert_called_once_with(
        ArtifactArn=obj.artifact_arn
    )
Exemplo n.º 3
0
def test_s3_uri_artifacts(sagemaker_session):
    """Check that ``s3_uri_artifacts`` queries by source URI and returns the raw response.

    The mocked ``list_artifacts`` returns one page with a single summary. The
    test asserts a single call keyed by ``SourceUri``. It also asserts that
    the method's result equals that page.
    """
    s3_uri = "s3://abced"
    obj = artifact.Artifact(
        sagemaker_session=sagemaker_session,
        artifact_arn="test-arn",
        artifact_name="foo",
        source_uri=s3_uri,
        properties={"k1": "v1", "k2": "v2"},
        properties_to_remove=["r1"],
    )

    def artifacts_page():
        # Build a fresh dict each time so the expected value is a separate,
        # equal object rather than the very response handed to the mock.
        summary = {
            "ArtifactArn": "A",
            "ArtifactName": "B",
            "Source": {
                "SourceUri": "D",
                "source_types": [{"SourceIdType": "source_id_type", "Value": "value1"}],
            },
            "ArtifactType": "test-type",
        }
        return {"ArtifactSummaries": [summary], "NextToken": "100"}

    client = sagemaker_session.sagemaker_client
    client.list_artifacts.side_effect = [artifacts_page()]

    result = obj.s3_uri_artifacts(s3_uri=s3_uri)

    assert [unittest.mock.call(SourceUri=s3_uri)] == client.list_artifacts.mock_calls
    assert result == artifacts_page()
Exemplo n.º 4
0
def test_save(sagemaker_session):
    """Check that ``save`` forwards the artifact's fields to ``update_artifact``."""
    obj = artifact.Artifact(
        sagemaker_session,
        artifact_arn="test-arn",
        artifact_name="foo",
        properties={"k1": "v1", "k2": "v2"},
        properties_to_remove=["r1"],
    )

    update_artifact = sagemaker_session.sagemaker_client.update_artifact
    update_artifact.return_value = {}
    obj.save()

    expected_kwargs = {
        "ArtifactArn": "test-arn",
        "ArtifactName": "foo",
        "Properties": {"k1": "v1", "k2": "v2"},
        "PropertiesToRemove": ["r1"],
    }
    update_artifact.assert_called_with(**expected_kwargs)
Exemplo n.º 5
0
def test_upstream_trials(sagemaker_session):
    """Check that ``upstream_trials`` returns trial names from the mocked search.

    The mocked ``query_lineage`` returns ten artifact vertices and one edge.
    The mocked ``search`` returns one trial component whose parent trial is
    ``test-trial-name``. The test asserts the returned trial names and the
    exact single ``query_lineage`` call: an ascendant lineage query from the
    artifact ARN, filtered to trial components.
    """
    client = sagemaker_session.sagemaker_client

    vertices = []
    for idx in range(10):
        vertices.append({"Arn": "B" + str(idx), "Type": "DataSet", "LineageType": "Artifact"})
    edge = {"SourceArn": "arn1", "DestinationArn": "arn2", "AssociationType": "Produced"}
    client.query_lineage.return_value = {"Vertices": vertices, "Edges": [edge]}

    trial_component = {
        "TrialComponentName": "tc-1",
        "TrialComponentArn": "arn::tc-1",
        "DisplayName": "TC1",
        "Parents": [{"TrialName": "test-trial-name"}],
    }
    client.search.return_value = {"Results": [{"TrialComponent": trial_component}]}

    obj = artifact.Artifact(
        sagemaker_session=sagemaker_session,
        artifact_arn="test-arn",
        artifact_name="foo",
        properties={"k1": "v1", "k2": "v2"},
        properties_to_remove=["r1"],
    )

    trials = obj.upstream_trials()
    assert ["test-trial-name"] == trials

    query_kwargs = {
        "Direction": "Ascendants",
        "Filters": {"LineageTypes": ["TrialComponent"]},
        "IncludeEdges": False,
        "MaxDepth": 10,
        "StartArns": ["test-arn"],
    }
    assert [unittest.mock.call(**query_kwargs)] == client.query_lineage.mock_calls
Exemplo n.º 6
0
def test_delete(sagemaker_session):
    """Check that ``delete`` calls ``delete_artifact`` with the artifact ARN."""
    obj = artifact.Artifact(sagemaker_session, artifact_arn="foo")
    delete_artifact = sagemaker_session.sagemaker_client.delete_artifact
    delete_artifact.return_value = {}

    obj.delete()

    delete_artifact.assert_called_with(ArtifactArn="foo")