Example #1
0
def test_trained_models(sagemaker_session):
    """Check DatasetArtifact.trained_models against two mocked association pages.

    Lineage modelled by the mocks: dataset artifact -> "experiment-trial-component"
    -> Context. The first ``list_associations`` response links the artifact to the
    trial component. The second lists that trial component's ``Context``
    destination, which is the single summary the method is expected to return.
    """
    tc_arn = "experiment-trial-component"
    dataset = artifact.DatasetArtifact(
        sagemaker_session,
        artifact_arn="dataset-artifact-arn",
        artifact_name="dataset-artifact-name",
    )

    artifact_to_tc = {
        "SourceArn": dataset.artifact_arn,
        "SourceName": "X1",
        "DestinationArn": tc_arn,
        "DestinationName": "Y1",
        "SourceType": "C1",
        "DestinationType": "ModelDeployment",
        "AssociationType": "E1",
        "CreationTime": None,
        "CreatedBy": None,
    }
    tc_to_context = {
        "SourceArn": tc_arn,
        "SourceName": "X2",
        "DestinationArn": "B2",
        "DestinationName": "Y2",
        "SourceType": "C2",
        "DestinationType": "Context",
        "AssociationType": "E2",
        "CreationTime": None,
        "CreatedBy": None,
    }
    list_associations = sagemaker_session.sagemaker_client.list_associations
    list_associations.side_effect = [
        {"AssociationSummaries": [artifact_to_tc]},
        {"AssociationSummaries": [tc_to_context]},
    ]

    trained = dataset.trained_models()

    # The first lookup is unfiltered; the follow-up restricts to Context destinations.
    assert [
        unittest.mock.call(SourceArn=dataset.artifact_arn),
        unittest.mock.call(SourceArn=tc_arn, DestinationType="Context"),
    ] == list_associations.mock_calls
    assert [
        _api_types.AssociationSummary(
            source_arn=tc_arn,
            source_name="X2",
            destination_arn="B2",
            destination_name="Y2",
            source_type="C2",
            destination_type="Context",
            association_type="E2",
            creation_time=None,
            created_by=None,
        )
    ] == trained
def test_models(sagemaker_session):
    """Check EndpointContext.models walks ModelDeployment -> Model associations.

    The endpoint context ("bazz") is mocked as linked to a ModelDeployment ("B0").
    That deployment is in turn linked to a Model ("B2"). The Model association
    summary is what ``models()`` should return.
    """
    endpoint_ctx = context.EndpointContext(
        sagemaker_session, context_name="foo", context_arn="bazz"
    )
    deployment_arn = "B0"

    def page(**entry):
        # Wrap a single summary in the response shape the mocked client returns.
        return {"AssociationSummaries": [entry]}

    client = sagemaker_session.sagemaker_client
    client.list_associations.side_effect = [
        page(
            SourceArn="bazz",
            SourceName="X1",
            DestinationArn=deployment_arn,
            DestinationName="Y1",
            SourceType="C1",
            DestinationType="ModelDeployment",
            AssociationType="E1",
            CreationTime=None,
            CreatedBy={},
        ),
        page(
            SourceArn=deployment_arn,
            SourceName="X2",
            DestinationArn="B2",
            DestinationName="Y2",
            SourceType="C2",
            DestinationType="Model",
            AssociationType="E2",
            CreationTime=None,
            CreatedBy={},
        ),
    ]

    found = endpoint_ctx.models()

    wanted_calls = [
        unittest.mock.call(
            SourceArn=endpoint_ctx.context_arn, DestinationType="ModelDeployment"
        ),
        unittest.mock.call(SourceArn=deployment_arn, DestinationType="Model"),
    ]
    assert wanted_calls == client.list_associations.mock_calls

    wanted_models = [
        _api_types.AssociationSummary(
            source_arn=deployment_arn,
            source_name="X2",
            destination_arn="B2",
            destination_name="Y2",
            source_type="C2",
            destination_type="Model",
            association_type="E2",
            creation_time=None,
            created_by={},
        )
    ]
    assert wanted_models == found
# Named for what it exercises (ModelArtifact.endpoints). Reusing the name
# `test_trained_models` would rebind it in this module. That would hide any
# earlier `test_trained_models` from pytest collection.
def test_model_artifact_endpoints(sagemaker_session):
    """Check ModelArtifact.endpoints walks Action -> Context associations.

    The model artifact is mocked as linked to "action-arn". That action is
    linked to "endpoint-context-arn". The resulting Context association summary
    is the single item ``endpoints()`` is expected to return, and the two
    ``list_associations`` lookups must filter on Action and then Context.
    """
    model_artifact_obj = artifact.ModelArtifact(
        sagemaker_session, artifact_arn="model-artifact-arn"
    )

    sagemaker_session.sagemaker_client.list_associations.side_effect = [
        {
            "AssociationSummaries": [
                {
                    "SourceArn": model_artifact_obj.artifact_arn,
                    "SourceName": "X1",
                    "DestinationArn": "action-arn",
                    "DestinationName": "Y1",
                    "SourceType": "C1",
                    "DestinationType": "Action",
                    "AssociationType": "E1",
                    "CreationTime": None,
                    "CreatedBy": {},
                }
            ],
        },
        {
            "AssociationSummaries": [
                {
                    "SourceArn": "action-arn",
                    "SourceName": "X2",
                    "DestinationArn": "endpoint-context-arn",
                    "DestinationName": "Y2",
                    "SourceType": "Action",
                    "DestinationType": "Context",
                    "AssociationType": "E2",
                    "CreationTime": None,
                    "CreatedBy": {},
                }
            ]
        },
    ]

    endpoint_context_list = model_artifact_obj.endpoints()

    expected_calls = [
        unittest.mock.call(
            SourceArn=model_artifact_obj.artifact_arn, DestinationType="Action"
        ),
        unittest.mock.call(SourceArn="action-arn", DestinationType="Context"),
    ]
    assert expected_calls == sagemaker_session.sagemaker_client.list_associations.mock_calls

    expected_endpoint_contexts = [
        _api_types.AssociationSummary(
            source_arn="action-arn",
            source_name="X2",
            destination_arn="endpoint-context-arn",
            destination_name="Y2",
            source_type="Action",
            destination_type="Context",
            association_type="E2",
            creation_time=None,
            created_by={},
        )
    ]
    assert expected_endpoint_contexts == endpoint_context_list
def test_list(sagemaker_session):
    """Check Association.list follows NextToken pagination and maps every summary.

    Two mocked pages are served: indices 0-9 carrying NextToken "100", then
    indices 10-19. They must come back as 20 AssociationSummary objects in
    order. The second request must repeat the filters/sort options and add the
    NextToken.
    """
    base_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=3)

    def raw_summary(idx):
        # One mocked summary entry, in the shape list_associations returns.
        tag = str(idx)
        return {
            "SourceArn": "A" + tag,
            "SourceName": "X" + tag,
            "DestinationArn": "B" + tag,
            "DestinationName": "Y" + tag,
            "SourceType": "C" + tag,
            "DestinationType": "D" + tag,
            "AssociationType": "E" + tag,
            "CreationTime": base_time + datetime.timedelta(hours=idx),
            "CreatedBy": {},
        }

    first_page = {
        "AssociationSummaries": [raw_summary(idx) for idx in range(10)],
        "NextToken": "100",
    }
    last_page = {
        "AssociationSummaries": [raw_summary(idx) for idx in range(10, 20)],
    }
    client = sagemaker_session.sagemaker_client
    client.list_associations.side_effect = [first_page, last_page]

    wanted = []
    for idx in range(20):
        tag = str(idx)
        wanted.append(
            _api_types.AssociationSummary(
                source_arn="A" + tag,
                source_name="X" + tag,
                destination_arn="B" + tag,
                destination_name="Y" + tag,
                source_type="C" + tag,
                destination_type="D" + tag,
                association_type="E" + tag,
                creation_time=base_time + datetime.timedelta(hours=idx),
                created_by={},
            )
        )

    listed = association.Association.list(
        sagemaker_session=sagemaker_session,
        source_arn="foo",
        destination_arn="bar",
        sort_by="CreationTime",
        sort_order="Ascending",
    )
    got = list(listed)

    assert wanted == got

    # Filters and sort options are re-sent on the follow-up page request.
    shared_kwargs = dict(
        SortBy="CreationTime",
        SortOrder="Ascending",
        SourceArn="foo",
        DestinationArn="bar",
    )
    wanted_calls = [
        unittest.mock.call(**shared_kwargs),
        unittest.mock.call(NextToken="100", **shared_kwargs),
    ]
    assert wanted_calls == client.list_associations.mock_calls