def test_trained_models(sagemaker_session):
    """Check DatasetArtifact.trained_models() against a mocked association graph.

    The mocked graph has two hops. The dataset artifact points at a trial
    component, and the trial component points at a Context. Only the
    second-hop summary is expected back from ``trained_models()``.
    """
    dataset = artifact.DatasetArtifact(
        sagemaker_session,
        artifact_arn="dataset-artifact-arn",
        artifact_name="dataset-artifact-name",
    )
    tc_arn = "experiment-trial-component"
    list_associations = sagemaker_session.sagemaker_client.list_associations

    # Hop 1: dataset artifact -> trial component.
    # NOTE(review): DestinationType is "ModelDeployment" even though the
    # destination is a trial component; presumably trained_models() keys off
    # the destination ARN rather than the type -- confirm.
    dataset_to_tc = {
        "SourceArn": dataset.artifact_arn,
        "SourceName": "X1",
        "DestinationArn": tc_arn,
        "DestinationName": "Y1",
        "SourceType": "C1",
        "DestinationType": "ModelDeployment",
        "AssociationType": "E1",
        "CreationTime": None,
        "CreatedBy": None,
    }
    # Hop 2: trial component -> Context (the "model" result).
    tc_to_model = {
        "SourceArn": tc_arn,
        "SourceName": "X2",
        "DestinationArn": "B2",
        "DestinationName": "Y2",
        "SourceType": "C2",
        "DestinationType": "Context",
        "AssociationType": "E2",
        "CreationTime": None,
        "CreatedBy": None,
    }
    # One response per list_associations() call, consumed in order.
    list_associations.side_effect = [
        {"AssociationSummaries": [dataset_to_tc]},
        {"AssociationSummaries": [tc_to_model]},
    ]

    models = dataset.trained_models()

    assert [
        unittest.mock.call(SourceArn=dataset.artifact_arn),
        unittest.mock.call(SourceArn=tc_arn, DestinationType="Context"),
    ] == list_associations.mock_calls

    assert [
        _api_types.AssociationSummary(
            source_arn=tc_arn,
            source_name="X2",
            destination_arn="B2",
            destination_name="Y2",
            source_type="C2",
            destination_type="Context",
            association_type="E2",
            creation_time=None,
            created_by=None,
        )
    ] == models
def test_models(sagemaker_session):
    """Check EndpointContext.models() walks endpoint -> ModelDeployment -> Model.

    The first mocked page links the endpoint context ("bazz") to a
    ModelDeployment ("B0"). The second links that deployment to a Model, and
    only that second summary is expected back from ``models()``.
    """
    endpoint_ctx = context.EndpointContext(
        sagemaker_session, context_name="foo", context_arn="bazz"
    )
    client = sagemaker_session.sagemaker_client
    deployment_arn = "B0"

    # Hop 1: endpoint context -> model deployment.
    ctx_to_deployment = {
        "SourceArn": "bazz",
        "SourceName": "X1",
        "DestinationArn": deployment_arn,
        "DestinationName": "Y1",
        "SourceType": "C1",
        "DestinationType": "ModelDeployment",
        "AssociationType": "E1",
        "CreationTime": None,
        "CreatedBy": {},
    }
    # Hop 2: model deployment -> model.
    deployment_to_model = {
        "SourceArn": deployment_arn,
        "SourceName": "X2",
        "DestinationArn": "B2",
        "DestinationName": "Y2",
        "SourceType": "C2",
        "DestinationType": "Model",
        "AssociationType": "E2",
        "CreationTime": None,
        "CreatedBy": {},
    }
    client.list_associations.side_effect = [
        {"AssociationSummaries": [ctx_to_deployment]},
        {"AssociationSummaries": [deployment_to_model]},
    ]

    found_models = endpoint_ctx.models()

    wanted_calls = [
        unittest.mock.call(
            SourceArn=endpoint_ctx.context_arn, DestinationType="ModelDeployment"
        ),
        unittest.mock.call(SourceArn=deployment_arn, DestinationType="Model"),
    ]
    assert wanted_calls == client.list_associations.mock_calls

    wanted_models = [
        _api_types.AssociationSummary(
            source_arn=deployment_arn,
            source_name="X2",
            destination_arn="B2",
            destination_name="Y2",
            source_type="C2",
            destination_type="Model",
            association_type="E2",
            creation_time=None,
            created_by={},
        )
    ]
    assert wanted_models == found_models
def test_endpoints(sagemaker_session):
    """Check ModelArtifact.endpoints() walks model artifact -> Action -> Context.

    Previously named ``test_trained_models``, the same name as the
    DatasetArtifact test. At module level the later ``def`` rebinds the name,
    so pytest collected only one of the two tests. The new name also matches
    the method under test.
    """
    model_artifact_obj = artifact.ModelArtifact(
        sagemaker_session, artifact_arn="model-artifact-arn"
    )
    # One response per list_associations() call, consumed in order:
    # hop 1 links the model artifact to an Action, hop 2 links that Action
    # to the endpoint Context that endpoints() should return.
    sagemaker_session.sagemaker_client.list_associations.side_effect = [
        {
            "AssociationSummaries": [
                {
                    "SourceArn": model_artifact_obj.artifact_arn,
                    "SourceName": "X1",
                    "DestinationArn": "action-arn",
                    "DestinationName": "Y1",
                    "SourceType": "C1",
                    "DestinationType": "Action",
                    "AssociationType": "E1",
                    "CreationTime": None,
                    "CreatedBy": {},
                }
            ],
        },
        {
            "AssociationSummaries": [
                {
                    "SourceArn": "action-arn",
                    "SourceName": "X2",
                    "DestinationArn": "endpoint-context-arn",
                    "DestinationName": "Y2",
                    "SourceType": "Action",
                    "DestinationType": "Context",
                    "AssociationType": "E2",
                    "CreationTime": None,
                    "CreatedBy": {},
                }
            ]
        },
    ]

    endpoint_context_list = model_artifact_obj.endpoints()

    expected_calls = [
        unittest.mock.call(
            SourceArn=model_artifact_obj.artifact_arn, DestinationType="Action"
        ),
        unittest.mock.call(SourceArn="action-arn", DestinationType="Context"),
    ]
    assert expected_calls == sagemaker_session.sagemaker_client.list_associations.mock_calls

    expected_endpoint_list = [
        _api_types.AssociationSummary(
            source_arn="action-arn",
            source_name="X2",
            destination_arn="endpoint-context-arn",
            destination_name="Y2",
            source_type="Action",
            destination_type="Context",
            association_type="E2",
            creation_time=None,
            created_by={},
        )
    ]
    assert expected_endpoint_list == endpoint_context_list
def test_list(sagemaker_session):
    """Check Association.list() follows NextToken paging and forwards filters.

    Two mocked pages of 10 summaries each are served. The first page carries
    ``NextToken="100"``. All 20 summaries must come back in order, and the
    second request must repeat the filters plus the token.
    """
    base_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(
        hours=3
    )

    def summary_page(indices):
        # Raw API-shaped summary dicts; the i-th entry is offset i hours.
        return [
            {
                "SourceArn": "A{}".format(i),
                "SourceName": "X{}".format(i),
                "DestinationArn": "B{}".format(i),
                "DestinationName": "Y{}".format(i),
                "SourceType": "C{}".format(i),
                "DestinationType": "D{}".format(i),
                "AssociationType": "E{}".format(i),
                "CreationTime": base_time + datetime.timedelta(hours=i),
                "CreatedBy": {},
            }
            for i in indices
        ]

    first_page = {"AssociationSummaries": summary_page(range(10)), "NextToken": "100"}
    second_page = {"AssociationSummaries": summary_page(range(10, 20))}
    client = sagemaker_session.sagemaker_client
    client.list_associations.side_effect = [first_page, second_page]

    wanted = [
        _api_types.AssociationSummary(
            source_arn="A{}".format(i),
            source_name="X{}".format(i),
            destination_arn="B{}".format(i),
            destination_name="Y{}".format(i),
            source_type="C{}".format(i),
            destination_type="D{}".format(i),
            association_type="E{}".format(i),
            creation_time=base_time + datetime.timedelta(hours=i),
            created_by={},
        )
        for i in range(20)
    ]

    # Materializing the iterator is what drives both paginated requests.
    got = list(
        association.Association.list(
            sagemaker_session=sagemaker_session,
            source_arn="foo",
            destination_arn="bar",
            sort_by="CreationTime",
            sort_order="Ascending",
        )
    )
    assert wanted == got

    shared_filters = dict(
        SortBy="CreationTime",
        SortOrder="Ascending",
        SourceArn="foo",
        DestinationArn="bar",
    )
    wanted_calls = [
        unittest.mock.call(**shared_filters),
        unittest.mock.call(NextToken="100", **shared_filters),
    ]
    assert wanted_calls == client.list_associations.mock_calls