def test_list_trial_components_two_values(sagemaker_boto_client, datetime_obj):
    """A single-page response with two entries yields exactly two TrialComponentSummary objects."""
    component_names = ["trial-component-foo-1", "trial-component-foo-2"]

    trial_obj = trial.Trial(sagemaker_boto_client=sagemaker_boto_client)
    # No "NextToken" in the payload, so only one page is served.
    sagemaker_boto_client.list_trial_components.return_value = {
        "TrialComponentSummaries": [
            {"TrialComponentName": name, "CreationTime": datetime_obj, "LastModifiedTime": datetime_obj}
            for name in component_names
        ]
    }

    actual = list(trial_obj.list_trial_components())
    expected = [
        api_types.TrialComponentSummary(
            trial_component_name=name, creation_time=datetime_obj, last_modified_time=datetime_obj
        )
        for name in component_names
    ]
    assert actual == expected
# Example #2
# 0
def test_remove_trial_component_from_trial_component_summary(sagemaker_boto_client):
    """Removing a component passed as a summary disassociates it from the trial by name."""
    summary = api_types.TrialComponentSummary()
    summary.trial_component_name = "tcs-foo"

    trial_obj = trial.Trial(sagemaker_boto_client)
    trial_obj.trial_name = "bar"
    trial_obj.remove_trial_component(summary)

    disassociate = sagemaker_boto_client.disassociate_trial_component
    disassociate.assert_called_with(TrialName="bar", TrialComponentName="tcs-foo")
def test_next_token(sagemaker_boto_client, datetime_obj):
    """list_trial_components follows NextToken across pages and yields every summary in order.

    The mocked client serves two pages. The first carries ``NextToken="foo"``
    and the second has no token, so exactly two requests are expected: one
    with no arguments, then one echoing the token.
    """

    def raw_summary(name):
        # One summary entry in the shape the mocked boto client returns.
        return {"TrialComponentName": name, "CreationTime": datetime_obj, "LastModifiedTime": datetime_obj}

    trial_obj = trial.Trial(sagemaker_boto_client)
    sagemaker_boto_client.list_trial_components.side_effect = [
        {
            "TrialComponentSummaries": [
                raw_summary("trial-component-foo-1"),
                raw_summary("trial-component-foo-2"),
            ],
            "NextToken": "foo",
        },
        {"TrialComponentSummaries": [raw_summary("trial-component-foo-3")]},
    ]

    assert list(trial_obj.list_trial_components()) == [
        api_types.TrialComponentSummary(
            trial_component_name=name, creation_time=datetime_obj, last_modified_time=datetime_obj
        )
        for name in ("trial-component-foo-1", "trial-component-foo-2", "trial-component-foo-3")
    ]

    # Compare the whole call sequence rather than using assert_any_call, so the
    # order (first page before the NextToken request) and the count are pinned too.
    assert sagemaker_boto_client.list_trial_components.mock_calls == [
        unittest.mock.call(),
        unittest.mock.call(NextToken="foo"),
    ]
# Example #4
# 0
def test_list_with_source_arn_and_sort(sagemaker_boto_client):
    """TrialComponent.list follows NextToken and forwards source/sort filters on every request.

    This test was originally also named ``test_list``. A later definition with
    the same name in this module rebound that name, so pytest never collected
    this one. The distinct name lets it run.
    """
    now = datetime.datetime.now(datetime.timezone.utc)
    start_time = now + datetime.timedelta(hours=1)
    end_time = now + datetime.timedelta(hours=2)
    creation_time = now + datetime.timedelta(hours=3)
    last_modified_time = now + datetime.timedelta(hours=4)

    def raw_summary(i):
        # The i-th summary entry in the shape the mocked boto client returns.
        offset = datetime.timedelta(hours=i)
        return {
            "TrialComponentName": "A" + str(i),
            "TrialComponentArn": "B" + str(i),
            "DisplayName": "C" + str(i),
            "SourceArn": "D" + str(i),
            "Status": {"PrimaryStatus": "InProgress", "Message": "E" + str(i)},
            "StartTime": start_time + offset,
            "EndTime": end_time + offset,
            "CreationTime": creation_time + offset,
            "LastModifiedTime": last_modified_time + offset,
            "LastModifiedBy": {},
        }

    def expected_summary(i):
        # The TrialComponentSummary that raw_summary(i) is expected to become.
        offset = datetime.timedelta(hours=i)
        return api_types.TrialComponentSummary(
            trial_component_name="A" + str(i),
            trial_component_arn="B" + str(i),
            display_name="C" + str(i),
            source_arn="D" + str(i),
            status=api_types.TrialComponentStatus(primary_status="InProgress", message="E" + str(i)),
            start_time=start_time + offset,
            end_time=end_time + offset,
            creation_time=creation_time + offset,
            last_modified_time=last_modified_time + offset,
            last_modified_by={},
        )

    # Two pages of ten entries; the first page's NextToken must be echoed on the second request.
    sagemaker_boto_client.list_trial_components.side_effect = [
        {"TrialComponentSummaries": [raw_summary(i) for i in range(10)], "NextToken": "100"},
        {"TrialComponentSummaries": [raw_summary(i) for i in range(10, 20)]},
    ]

    result = list(
        trial_component.TrialComponent.list(
            sagemaker_boto_client=sagemaker_boto_client,
            source_arn="foo",
            sort_by="CreationTime",
            sort_order="Ascending",
        )
    )

    assert result == [expected_summary(i) for i in range(20)]
    assert sagemaker_boto_client.list_trial_components.mock_calls == [
        unittest.mock.call(SortBy="CreationTime", SortOrder="Ascending", SourceArn="foo"),
        unittest.mock.call(NextToken="100", SortBy="CreationTime", SortOrder="Ascending", SourceArn="foo"),
    ]
def test_list(sagemaker_boto_client):
    """TrialComponent.list pages through results and forwards source/sort arguments.

    Two mocked pages of ten summaries each are served. The first page carries
    NextToken '100', which must be echoed on the second request along with the
    original filter and sort arguments.
    """

    def hours_from_now(hours):
        return datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=hours)

    start_time, end_time, creation_time, last_modified_time = (
        hours_from_now(h) for h in (1, 2, 3, 4)
    )

    def raw_summary(idx):
        # One summary entry, shifted by idx hours, as the mocked client returns it.
        shift = datetime.timedelta(hours=idx)
        return {
            'TrialComponentName': 'A' + str(idx),
            'TrialComponentArn': 'B' + str(idx),
            'DisplayName': 'C' + str(idx),
            'SourceArn': 'D' + str(idx),
            'Status': {'PrimaryStatus': 'InProgress', 'Message': 'E' + str(idx)},
            'StartTime': start_time + shift,
            'EndTime': end_time + shift,
            'CreationTime': creation_time + shift,
            'LastModifiedTime': last_modified_time + shift,
            'LastModifiedBy': {},
        }

    first_page = {
        'TrialComponentSummaries': [raw_summary(idx) for idx in range(10)],
        'NextToken': '100',
    }
    second_page = {'TrialComponentSummaries': [raw_summary(idx) for idx in range(10, 20)]}
    sagemaker_boto_client.list_trial_components.side_effect = [first_page, second_page]

    expected = []
    for idx in range(20):
        shift = datetime.timedelta(hours=idx)
        expected.append(
            api_types.TrialComponentSummary(
                trial_component_name='A' + str(idx),
                trial_component_arn='B' + str(idx),
                display_name='C' + str(idx),
                source_arn='D' + str(idx),
                status=api_types.TrialComponentStatus(primary_status='InProgress', message='E' + str(idx)),
                start_time=start_time + shift,
                end_time=end_time + shift,
                creation_time=creation_time + shift,
                last_modified_time=last_modified_time + shift,
                last_modified_by={},
            )
        )

    list_options = {'source_arn': 'foo', 'sort_by': 'CreationTime', 'sort_order': 'Ascending'}
    result = list(
        trial_component.TrialComponent.list(sagemaker_boto_client=sagemaker_boto_client, **list_options)
    )
    assert expected == result

    # Every request must carry the same filter/sort kwargs; the second also echoes the token.
    request_kwargs = {'SortBy': 'CreationTime', 'SortOrder': 'Ascending', 'SourceArn': 'foo'}
    expected_calls = [
        unittest.mock.call(**request_kwargs),
        unittest.mock.call(NextToken='100', **request_kwargs),
    ]
    assert expected_calls == sagemaker_boto_client.list_trial_components.mock_calls