コード例 #1
0
def test_list_trials_call_args(sagemaker_boto_client):
    """list_trials forwards the created_before / created_after filters to boto."""
    before = datetime.datetime(1999, 10, 12)
    after = datetime.datetime(1990, 10, 12)
    exp = experiment.Experiment(sagemaker_boto_client=sagemaker_boto_client)
    # An empty response dict means there are no trials to yield.
    sagemaker_boto_client.list_trials.return_value = {}

    trials = list(exp.list_trials(created_after=after, created_before=before))

    assert trials == []
    sagemaker_boto_client.list_trials.assert_called_with(
        CreatedBefore=before, CreatedAfter=after
    )
コード例 #2
0
def test_list_trials_two_values(sagemaker_boto_client, datetime_obj):
    """Each boto trial summary is converted into an api_types.TrialSummary, in order."""
    trial_names = ["trial-foo-1", "trial-foo-2"]
    exp = experiment.Experiment(sagemaker_boto_client=sagemaker_boto_client)
    sagemaker_boto_client.list_trials.return_value = {
        "TrialSummaries": [
            {"Name": trial_name, "CreationTime": datetime_obj, "LastModifiedTime": datetime_obj}
            for trial_name in trial_names
        ]
    }

    actual = list(exp.list_trials())
    expected = [
        api_types.TrialSummary(
            name=trial_name, creation_time=datetime_obj, last_modified_time=datetime_obj
        )
        for trial_name in trial_names
    ]
    assert actual == expected
コード例 #3
0
def test_delete(sagemaker_boto_client):
    """delete() calls boto's delete_experiment with the experiment's name."""
    exp = experiment.Experiment(sagemaker_boto_client, experiment_name='foo', description='bar')
    delete_mock = sagemaker_boto_client.delete_experiment
    delete_mock.return_value = {}

    exp.delete()

    delete_mock.assert_called_with(ExperimentName='foo')
コード例 #4
0
def test_save(sagemaker_boto_client):
    """save() sends the experiment name and description through update_experiment."""
    exp = experiment.Experiment(sagemaker_boto_client, experiment_name="foo", description="bar")
    update_mock = sagemaker_boto_client.update_experiment
    update_mock.return_value = {}

    exp.save()

    update_mock.assert_called_with(ExperimentName="foo", Description="bar")
コード例 #5
0
def test_delete_all_fail(sagemaker_boto_client):
    """A failure while listing trials surfaces as the generic delete-failed error."""
    exp = experiment.Experiment(sagemaker_boto_client, experiment_name="foo", description="bar")
    # Make the very first boto call delete_all relies on blow up.
    sagemaker_boto_client.list_trials.side_effect = Exception

    with pytest.raises(Exception) as excinfo:
        exp.delete_all(action="--force")

    assert str(excinfo.value) == "Failed to delete, please try again."
コード例 #6
0
def test_experiment_create_trial_with_name(sagemaker_boto_client):
    """create_trial(trial_name=...) passes the explicit name and parent experiment to boto."""
    exp = experiment.Experiment(sagemaker_boto_client=sagemaker_boto_client)
    exp.experiment_name = "someExperimentName"
    create_trial_mock = sagemaker_boto_client.create_trial
    create_trial_mock.return_value = {"Arn": "arn:aws:1234", "TrialName": "someTrialName"}

    exp.create_trial(trial_name="someTrialName")

    create_trial_mock.assert_called_with(
        TrialName="someTrialName", ExperimentName="someExperimentName"
    )
コード例 #7
0
def test_experiment_create_trial_with_prefix(sagemaker_boto_client):
    """create_trial(trial_name_prefix=...) sends a trial name that starts with the prefix.

    The generated trial name is not fixed, so the keyword arguments of the
    create_trial call are inspected directly instead of using assert_called_with.
    """
    experiment_obj = experiment.Experiment(sagemaker_boto_client=sagemaker_boto_client)
    experiment_obj.experiment_name = "someExperimentName"
    sagemaker_boto_client.create_trial.return_value = {
        "Arn": "arn:aws:1234",
        "TrialName": "someTrialName1234",
    }

    experiment_obj.create_trial(trial_name_prefix="someTrialName")

    # Read the call from the create_trial mock itself. Indexing the client's
    # mock_calls[0] would pick whichever client method happened to run first.
    sagemaker_boto_client.create_trial.assert_called_once()
    _, kwargs = sagemaker_boto_client.create_trial.call_args
    assert kwargs["ExperimentName"] == "someExperimentName"
    assert kwargs["TrialName"].startswith("someTrialName")
コード例 #8
0
def test_next_token(sagemaker_boto_client, datetime_obj):
    """list_trials follows NextToken across pages and yields every summary in order."""

    def boto_summary(trial_name):
        # One raw TrialSummaries entry as the mocked boto client returns it.
        return {
            "Name": trial_name,
            "CreationTime": datetime_obj,
            "LastModifiedTime": datetime_obj,
        }

    exp = experiment.Experiment(sagemaker_boto_client)
    first_page = {
        "TrialSummaries": [boto_summary("trial-foo-1"), boto_summary("trial-foo-2")],
        "NextToken": "foo",
    }
    final_page = {"TrialSummaries": [boto_summary("trial-foo-3")]}
    sagemaker_boto_client.list_trials.side_effect = [first_page, final_page]

    actual = list(exp.list_trials())
    expected = [
        api_types.TrialSummary(
            name=trial_name, creation_time=datetime_obj, last_modified_time=datetime_obj
        )
        for trial_name in ("trial-foo-1", "trial-foo-2", "trial-foo-3")
    ]
    assert actual == expected

    # First request carries no token; the second resumes from the first page's token.
    list_trials_mock = sagemaker_boto_client.list_trials
    list_trials_mock.assert_any_call()
    list_trials_mock.assert_any_call(NextToken="foo")
コード例 #9
0
def test_list_trials_empty(sagemaker_boto_client):
    """An empty TrialSummaries response yields no trials."""
    sagemaker_boto_client.list_trials.return_value = {"TrialSummaries": []}
    exp = experiment.Experiment(sagemaker_boto_client=sagemaker_boto_client)
    trials = list(exp.list_trials())
    assert not trials
コード例 #10
0
def test_delete_all(sagemaker_boto_client, datetime_obj):
    """delete_all(action="--force") deletes every trial component, every trial and the experiment.

    The mocked client lists two trials with two trial components each. The test
    checks the exact delete_trial_component and delete_trial call sequences, and
    that delete_experiment was called with the experiment's name.
    """
    # datetime_obj must be requested as a fixture argument. Referencing it as a bare
    # module-level name picks up the fixture function object (or raises NameError
    # when the fixture lives in conftest.py) instead of a datetime value.
    obj = experiment.Experiment(sagemaker_boto_client, experiment_name="foo", description="bar")
    sagemaker_boto_client.list_trials.return_value = {
        "TrialSummaries": [
            {"TrialName": "trial-1", "CreationTime": datetime_obj, "LastModifiedTime": datetime_obj},
            {"TrialName": "trial-2", "CreationTime": datetime_obj, "LastModifiedTime": datetime_obj},
        ]
    }
    # Use "TrialName" here, matching the key spelling used by list_trials above.
    sagemaker_boto_client.describe_trial.side_effect = [
        {"TrialName": "trial-1", "ExperimentName": "experiment-name-value"},
        {"TrialName": "trial-2", "ExperimentName": "experiment-name-value"},
    ]
    # One list_trial_components page per trial, consumed in trial order.
    sagemaker_boto_client.list_trial_components.side_effect = [
        {
            "TrialComponentSummaries": [
                {
                    "TrialComponentName": "trial-component-1",
                    "CreationTime": datetime_obj,
                    "LastModifiedTime": datetime_obj,
                },
                {
                    "TrialComponentName": "trial-component-2",
                    "CreationTime": datetime_obj,
                    "LastModifiedTime": datetime_obj,
                },
            ]
        },
        {
            "TrialComponentSummaries": [
                {
                    "TrialComponentName": "trial-component-3",
                    "CreationTime": datetime_obj,
                    "LastModifiedTime": datetime_obj,
                },
                {
                    "TrialComponentName": "trial-component-4",
                    "CreationTime": datetime_obj,
                    "LastModifiedTime": datetime_obj,
                },
            ]
        },
    ]

    sagemaker_boto_client.describe_trial_component.side_effect = [
        {"TrialComponentName": "trial-component-1"},
        {"TrialComponentName": "trial-component-2"},
        {"TrialComponentName": "trial-component-3"},
        {"TrialComponentName": "trial-component-4"},
    ]

    sagemaker_boto_client.delete_trial_component.return_value = {}
    sagemaker_boto_client.delete_trial.return_value = {}
    sagemaker_boto_client.delete_experiment.return_value = {}

    obj.delete_all(action="--force")

    sagemaker_boto_client.delete_experiment.assert_called_with(ExperimentName="foo")

    delete_trial_expected_calls = [
        unittest.mock.call(TrialName="trial-1"),
        unittest.mock.call(TrialName="trial-2"),
    ]
    assert delete_trial_expected_calls == sagemaker_boto_client.delete_trial.mock_calls

    delete_trial_component_expected_calls = [
        unittest.mock.call(TrialComponentName="trial-component-1"),
        unittest.mock.call(TrialComponentName="trial-component-2"),
        unittest.mock.call(TrialComponentName="trial-component-3"),
        unittest.mock.call(TrialComponentName="trial-component-4"),
    ]
    assert delete_trial_component_expected_calls == sagemaker_boto_client.delete_trial_component.mock_calls
コード例 #11
0
def test_delete_all_with_incorrect_action_name(sagemaker_boto_client):
    """delete_all raises ValueError when given an unrecognized action string."""
    exp = experiment.Experiment(
        sagemaker_boto_client,
        experiment_name="foo",
        description="bar",
    )
    with pytest.raises(ValueError):
        exp.delete_all(action="abc")