Example #1
0
def test_list_trials_with_trial_component_name(sagemaker_boto_client,
                                               datetime_obj):
    """Trial.list filtered by trial component yields one summary per trial.

    The mocked ``list_trials`` response carries two summaries keyed by
    ``TrialName``. The test checks that they come back as
    ``api_types.TrialSummary`` objects. It also checks that the filter is
    forwarded to the client as ``TrialComponentName``.
    """
    trial_names = ("trial-1", "trial-2")

    # NOTE(review): this test uses the "TrialName" key / trial_name= field,
    # while other tests here use "Name" / name=; confirm which matches the
    # current list_trials response and TrialSummary shape.
    sagemaker_boto_client.list_trials.return_value = {
        "TrialSummaries": [{
            "TrialName": trial_name,
            "CreationTime": datetime_obj,
            "LastModifiedTime": datetime_obj,
        } for trial_name in trial_names]
    }
    expected = [
        api_types.TrialSummary(trial_name=trial_name,
                               creation_time=datetime_obj,
                               last_modified_time=datetime_obj)
        for trial_name in trial_names
    ]

    listed = trial.Trial.list(trial_component_name="tc-foo",
                              sagemaker_boto_client=sagemaker_boto_client)
    assert expected == list(listed)

    sagemaker_boto_client.list_trials.assert_called_with(
        TrialComponentName="tc-foo")


def test_list_trials_two_values(sagemaker_boto_client, datetime_obj):
    """Experiment.list_trials turns each returned summary into a TrialSummary.

    A single mocked ``list_trials`` response (no ``NextToken``) holds two
    summaries. Both must be yielded, in order, as ``api_types.TrialSummary``.
    """
    experiment_obj = experiment.Experiment(
        sagemaker_boto_client=sagemaker_boto_client)

    names = ["trial-foo-1", "trial-foo-2"]
    sagemaker_boto_client.list_trials.return_value = {
        "TrialSummaries": [
            dict(Name=name,
                 CreationTime=datetime_obj,
                 LastModifiedTime=datetime_obj)
            for name in names
        ]
    }

    actual = list(experiment_obj.list_trials())
    assert actual == [
        api_types.TrialSummary(name=name,
                               creation_time=datetime_obj,
                               last_modified_time=datetime_obj)
        for name in names
    ]


def test_next_token(sagemaker_boto_client, datetime_obj):
    """Experiment.list_trials follows ``NextToken`` across response pages.

    The first mocked page returns two summaries plus ``NextToken="foo"``.
    The second page returns one summary and no token. All three summaries
    must be yielded in order. ``list_trials`` must have been called once
    with no arguments and once with ``NextToken="foo"``.
    """
    experiment_obj = experiment.Experiment(sagemaker_boto_client)

    def summary_dict(name):
        # Raw summary entry as it appears in a mocked list_trials page.
        return {
            "Name": name,
            "CreationTime": datetime_obj,
            "LastModifiedTime": datetime_obj,
        }

    first_page = {
        "TrialSummaries": [
            summary_dict("trial-foo-1"),
            summary_dict("trial-foo-2"),
        ],
        "NextToken": "foo",
    }
    last_page = {"TrialSummaries": [summary_dict("trial-foo-3")]}
    # Successive client calls return the pages in this order.
    sagemaker_boto_client.list_trials.side_effect = [first_page, last_page]

    expected_names = ["trial-foo-1", "trial-foo-2", "trial-foo-3"]
    assert list(experiment_obj.list_trials()) == [
        api_types.TrialSummary(name=name,
                               creation_time=datetime_obj,
                               last_modified_time=datetime_obj)
        for name in expected_names
    ]

    list_trials_mock = sagemaker_boto_client.list_trials
    list_trials_mock.assert_any_call()
    list_trials_mock.assert_any_call(NextToken="foo")