def test_downstream_trials(sagemaker_session):
    """Check that Artifact.downstream_trials reports the parent trial of a related component.

    ``list_associations`` is stubbed with one page (``NextToken`` is None) of
    ten synthetic association summaries. ``search`` is stubbed with a single
    trial component whose parent trial is ``test-trial-name``. The test expects
    that trial name back, and expects ``list_associations`` to have been called
    exactly once, keyed on the artifact's ARN.
    """
    client = sagemaker_session.sagemaker_client

    summaries = []
    for idx in range(10):
        suffix = str(idx)
        summaries.append(
            {
                "SourceArn": "A" + suffix,
                "SourceName": "X" + suffix,
                "DestinationArn": "B" + suffix,
                "DestinationName": "Y" + suffix,
                "SourceType": "C" + suffix,
                "DestinationType": "D" + suffix,
                "AssociationType": "E" + suffix,
                "CreationTime": datetime.datetime.now(),
                "CreatedBy": {},
            }
        )
    client.list_associations.side_effect = [
        {"AssociationSummaries": summaries, "NextToken": None},
    ]

    trial_name = "test-trial-name"
    component = {
        "TrialComponentName": "tc-1",
        "TrialComponentArn": "arn::tc-1",
        "DisplayName": "TC1",
        "Parents": [{"TrialName": trial_name}],
    }
    client.search.return_value = {"Results": [{"TrialComponent": component}]}

    art = artifact.Artifact(
        sagemaker_session=sagemaker_session,
        artifact_arn="test-arn",
        artifact_name="foo",
        properties={"k1": "v1", "k2": "v2"},
        properties_to_remove=["r1"],
    )

    trials = art.downstream_trials(sagemaker_session=sagemaker_session)

    assert [trial_name] == trials
    assert [unittest.mock.call(SourceArn="test-arn")] == client.list_associations.mock_calls
def test_create_delete_with_association(sagemaker_session):
    """Check that Artifact.delete(disassociate=True) removes associations, then the artifact.

    ``list_associations`` is stubbed with two single-page responses:

    - the first holds one association whose source is the artifact (destination ``B0``);
    - the second holds one association whose destination is the artifact (source ``A1``).

    The test expects one ``delete_association`` call per association, in that
    order. It also expects exactly one ``delete_artifact`` call for the
    artifact itself.
    """
    obj = artifact.Artifact(sagemaker_session, artifact_arn="foo")
    sagemaker_session.sagemaker_client.list_associations.side_effect = [
        {
            # Associations where the artifact is the source.
            "AssociationSummaries": [
                {
                    "SourceArn": obj.artifact_arn,
                    "SourceName": "X" + str(i),
                    "DestinationArn": "B" + str(i),
                    "DestinationName": "Y" + str(i),
                    "SourceType": "C" + str(i),
                    "DestinationType": "D" + str(i),
                    "AssociationType": "E" + str(i),
                    "CreationTime": None,
                    "CreatedBy": {},
                }
                for i in range(1)
            ],
        },
        {
            # Associations where the artifact is the destination.
            "AssociationSummaries": [
                {
                    "SourceArn": "A" + str(i),
                    "SourceName": "X" + str(i),
                    "DestinationArn": obj.artifact_arn,
                    "DestinationName": "Y" + str(i),
                    "SourceType": "C" + str(i),
                    "DestinationType": "D" + str(i),
                    "AssociationType": "E" + str(i),
                    "CreationTime": None,
                    "CreatedBy": {},
                }
                for i in range(1, 2)
            ]
        },
    ]
    sagemaker_session.sagemaker_client.delete_association.return_value = {}
    sagemaker_session.sagemaker_client.delete_artifact.return_value = {}

    obj.delete(disassociate=True)

    delete_with_association_expected_calls = [
        unittest.mock.call(SourceArn=obj.artifact_arn, DestinationArn="B0"),
        unittest.mock.call(SourceArn="A1", DestinationArn=obj.artifact_arn),
    ]
    assert (
        delete_with_association_expected_calls
        == sagemaker_session.sagemaker_client.delete_association.mock_calls
    )
    # Disassociation is in addition to, not instead of, deleting the artifact
    # itself. Without this check, a regression that skipped the final delete
    # would go unnoticed.
    sagemaker_session.sagemaker_client.delete_artifact.assert_called_once_with(
        ArtifactArn=obj.artifact_arn
    )
def test_s3_uri_artifacts(sagemaker_session):
    """Check that Artifact.s3_uri_artifacts lists artifacts by source URI and returns the response.

    ``list_artifacts`` is stubbed with a single response. The test expects one
    ``list_artifacts(SourceUri=...)`` call, and a return value equal to that
    response.
    """

    def _artifact_page():
        # Build a fresh dict on every call, so the expected value is a separate
        # object from the one handed to the mock.
        return {
            "ArtifactSummaries": [
                {
                    "ArtifactArn": "A",
                    "ArtifactName": "B",
                    "Source": {
                        "SourceUri": "D",
                        "source_types": [{"SourceIdType": "source_id_type", "Value": "value1"}],
                    },
                    "ArtifactType": "test-type",
                }
            ],
            "NextToken": "100",
        }

    uri = "s3://abced"
    art = artifact.Artifact(
        sagemaker_session=sagemaker_session,
        artifact_arn="test-arn",
        artifact_name="foo",
        source_uri=uri,
        properties={"k1": "v1", "k2": "v2"},
        properties_to_remove=["r1"],
    )
    client = sagemaker_session.sagemaker_client
    client.list_artifacts.side_effect = [_artifact_page()]

    outcome = art.s3_uri_artifacts(s3_uri=uri)

    assert [unittest.mock.call(SourceUri=uri)] == client.list_artifacts.mock_calls
    assert outcome == _artifact_page()
def test_save(sagemaker_session):
    """Check that Artifact.save sends the ARN, name and property edits to update_artifact."""
    art = artifact.Artifact(
        sagemaker_session,
        artifact_arn="test-arn",
        artifact_name="foo",
        properties={"k1": "v1", "k2": "v2"},
        properties_to_remove=["r1"],
    )
    update = sagemaker_session.sagemaker_client.update_artifact
    update.return_value = {}

    art.save()

    update.assert_called_with(
        ArtifactArn="test-arn",
        ArtifactName="foo",
        Properties={"k1": "v1", "k2": "v2"},
        PropertiesToRemove=["r1"],
    )
def test_upstream_trials(sagemaker_session):
    """Check that Artifact.upstream_trials reports the parent trial found through lineage.

    ``query_lineage`` is stubbed with ten artifact vertices and one edge.
    ``search`` is stubbed with a single trial component whose parent trial is
    ``test-trial-name``. The test expects that trial name back. It also expects
    exactly one ``query_lineage`` call with the ascendant/trial-component query
    shown below.
    """
    client = sagemaker_session.sagemaker_client

    client.query_lineage.return_value = {
        "Vertices": [
            {"Arn": f"B{n}", "Type": "DataSet", "LineageType": "Artifact"} for n in range(10)
        ],
        "Edges": [{"SourceArn": "arn1", "DestinationArn": "arn2", "AssociationType": "Produced"}],
    }

    trial_name = "test-trial-name"
    component = {
        "TrialComponentName": "tc-1",
        "TrialComponentArn": "arn::tc-1",
        "DisplayName": "TC1",
        "Parents": [{"TrialName": trial_name}],
    }
    client.search.return_value = {"Results": [{"TrialComponent": component}]}

    art = artifact.Artifact(
        sagemaker_session=sagemaker_session,
        artifact_arn="test-arn",
        artifact_name="foo",
        properties={"k1": "v1", "k2": "v2"},
        properties_to_remove=["r1"],
    )

    trials = art.upstream_trials()

    assert [trial_name] == trials

    lineage_query = unittest.mock.call(
        Direction="Ascendants",
        Filters={"LineageTypes": ["TrialComponent"]},
        IncludeEdges=False,
        MaxDepth=10,
        StartArns=["test-arn"],
    )
    assert [lineage_query] == client.query_lineage.mock_calls
def test_delete(sagemaker_session):
    """Check that Artifact.delete (without disassociation) calls delete_artifact with the artifact ARN."""
    art = artifact.Artifact(sagemaker_session, artifact_arn="foo")
    remove = sagemaker_session.sagemaker_client.delete_artifact
    remove.return_value = {}

    art.delete()

    remove.assert_called_with(ArtifactArn="foo")