Exemplo n.º 1
0
def test_remove_trial_component_from_trial_component_summary(sagemaker_boto_client):
    """Removing via a ``TrialComponentSummary`` disassociates by the summary's component name."""
    trial_obj = trial.Trial(sagemaker_boto_client)
    trial_obj.trial_name = "bar"

    summary = api_types.TrialComponentSummary()
    summary.trial_component_name = "tcs-foo"
    trial_obj.remove_trial_component(summary)

    expected_kwargs = dict(TrialName="bar", TrialComponentName="tcs-foo")
    sagemaker_boto_client.disassociate_trial_component.assert_called_with(**expected_kwargs)
Exemplo n.º 2
0
def test_remove_trial_component_from_tracker(sagemaker_boto_client):
    """Removing via a ``Tracker`` disassociates the trial component the tracker was built from."""
    trial_obj = trial.Trial(sagemaker_boto_client)
    trial_obj.trial_name = "bar"

    component = trial_component.TrialComponent(
        trial_component_name="tc-foo", sagemaker_boto_client=sagemaker_boto_client
    )
    first_mock = unittest.mock.Mock()
    second_mock = unittest.mock.Mock()
    component_tracker = tracker.Tracker(component, first_mock, second_mock)

    trial_obj.remove_trial_component(component_tracker)

    disassociate = sagemaker_boto_client.disassociate_trial_component
    disassociate.assert_called_with(TrialName="bar", TrialComponentName="tc-foo")
def test_list_trial_components_two_values(sagemaker_boto_client, datetime_obj):
    """Every summary dict returned by the client becomes one ``TrialComponentSummary``."""
    trial_obj = trial.Trial(sagemaker_boto_client=sagemaker_boto_client)
    component_names = ["trial-component-foo-1", "trial-component-foo-2"]

    sagemaker_boto_client.list_trial_components.return_value = {
        "TrialComponentSummaries": [
            {"TrialComponentName": name, "CreationTime": datetime_obj, "LastModifiedTime": datetime_obj}
            for name in component_names
        ]
    }

    listed = list(trial_obj.list_trial_components())
    expected = [
        api_types.TrialComponentSummary(
            trial_component_name=name, creation_time=datetime_obj, last_modified_time=datetime_obj
        )
        for name in component_names
    ]
    assert listed == expected
def test_next_token(sagemaker_boto_client, datetime_obj):
    """Listing follows ``NextToken`` and yields the summaries from every page.

    The mocked client serves two pages; the first carries ``NextToken="foo"``.
    The client must have been called once with no arguments and once with
    ``NextToken="foo"``.
    """
    trial_obj = trial.Trial(sagemaker_boto_client)

    def _raw_summary(name):
        # Shape of a single entry in the boto ListTrialComponents response used by this test.
        return {"TrialComponentName": name, "CreationTime": datetime_obj, "LastModifiedTime": datetime_obj}

    first_page = {
        "TrialComponentSummaries": [
            _raw_summary("trial-component-foo-1"),
            _raw_summary("trial-component-foo-2"),
        ],
        "NextToken": "foo",
    }
    last_page = {"TrialComponentSummaries": [_raw_summary("trial-component-foo-3")]}
    sagemaker_boto_client.list_trial_components.side_effect = [first_page, last_page]

    listed = list(trial_obj.list_trial_components())
    expected = [
        api_types.TrialComponentSummary(
            trial_component_name=name, creation_time=datetime_obj, last_modified_time=datetime_obj
        )
        for name in ("trial-component-foo-1", "trial-component-foo-2", "trial-component-foo-3")
    ]
    assert listed == expected

    list_mock = sagemaker_boto_client.list_trial_components
    list_mock.assert_any_call()
    list_mock.assert_any_call(NextToken="foo")
Exemplo n.º 5
0
def test_add_trial_component(sagemaker_boto_client):
    """``add_trial_component`` accepts a plain name, a ``TrialComponent`` or a ``Tracker``.

    After each addition, the most recent ``associate_trial_component`` call must
    carry the trial name and the added component's name.
    """
    trial_obj = trial.Trial(sagemaker_boto_client)
    trial_obj.trial_name = "bar"
    associate = sagemaker_boto_client.associate_trial_component

    # Case 1: component given by name.
    trial_obj.add_trial_component("foo")
    associate.assert_called_with(TrialName="bar", TrialComponentName="foo")

    # Case 2: component given as a TrialComponent instance.
    component = trial_component.TrialComponent(
        trial_component_name="tc-foo", sagemaker_boto_client=sagemaker_boto_client
    )
    trial_obj.add_trial_component(component)
    associate.assert_called_with(TrialName="bar", TrialComponentName=component.trial_component_name)

    # Case 3: component given as a Tracker wrapping a TrialComponent.
    tracked_component = trial_component.TrialComponent(
        trial_component_name="tc-foo2", sagemaker_boto_client=sagemaker_boto_client
    )
    wrapper = tracker.Tracker(tracked_component, unittest.mock.Mock(), unittest.mock.Mock())
    trial_obj.add_trial_component(wrapper)
    associate.assert_called_with(TrialName="bar", TrialComponentName=tracked_component.trial_component_name)
def test_list_trial_components_call_args(sagemaker_boto_client):
    """Filter and paging arguments are forwarded to the client together with the trial name."""
    lower_bound = datetime.datetime(1990, 10, 12, 0, 0, 0)
    upper_bound = datetime.datetime(1999, 10, 12, 0, 0, 0)
    name_of_trial = "foo-trial"
    token = "thetoken"
    page_size = 99

    trial_obj = trial.Trial(sagemaker_boto_client=sagemaker_boto_client)
    trial_obj.trial_name = name_of_trial

    # An empty response dict means there is nothing to yield.
    sagemaker_boto_client.list_trial_components.return_value = {}
    listed = list(
        trial_obj.list_trial_components(
            created_after=lower_bound,
            created_before=upper_bound,
            next_token=token,
            max_results=page_size,
        )
    )
    assert listed == []

    sagemaker_boto_client.list_trial_components.assert_called_with(
        CreatedBefore=upper_bound,
        CreatedAfter=lower_bound,
        TrialName=name_of_trial,
        NextToken=token,
        MaxResults=page_size,
    )
Exemplo n.º 7
0
def test_boto_ignore(sagemaker_boto_client):
    """``Trial._boto_ignore`` returns exactly ``["ResponseMetadata", "CreatedBy"]``.

    The mocked client is requested as a pytest fixture. Previously the test
    referenced the name ``sagemaker_boto_client`` without requesting it. That
    bound the module-level name (the fixture function itself, if defined in
    this module) instead of the mocked client, or raised NameError otherwise.
    """
    obj = trial.Trial(sagemaker_boto_client, trial_name="foo")
    assert obj._boto_ignore() == ["ResponseMetadata", "CreatedBy"]
Exemplo n.º 8
0
def test_delete(sagemaker_boto_client):
    """``delete`` issues ``delete_trial`` for the trial's own name."""
    trial_obj = trial.Trial(sagemaker_boto_client, trial_name="foo")
    delete_mock = sagemaker_boto_client.delete_trial
    delete_mock.return_value = {}

    trial_obj.delete()

    delete_mock.assert_called_with(TrialName="foo")
Exemplo n.º 9
0
def test_remove_trial_component(sagemaker_boto_client):
    """Removing a component by plain name disassociates it from the trial.

    NOTE(review): another function named ``test_remove_trial_component`` is
    defined later in this file. Python keeps only the last binding of a module
    name, so this definition is shadowed and pytest will not collect it.
    Consider renaming one of the two.
    """
    trial_obj = trial.Trial(sagemaker_boto_client)
    trial_obj.trial_name = "bar"
    trial_obj.remove_trial_component("foo")

    disassociate = sagemaker_boto_client.disassociate_trial_component
    disassociate.assert_called_with(TrialName="bar", TrialComponentName="foo")
def test_remove_trial_component(sagemaker_boto_client):
    """Removing a component by name calls ``disassociate_trial_component`` with both names.

    NOTE(review): this redefines an earlier function of the same name in this
    file, so only this definition is visible to pytest. Consider giving one of
    them a distinct name.
    """
    expected = {"TrialName": "bar", "TrialComponentName": "foo"}

    trial_obj = trial.Trial(sagemaker_boto_client)
    trial_obj.trial_name = "bar"
    trial_obj.remove_trial_component("foo")

    sagemaker_boto_client.disassociate_trial_component.assert_called_with(**expected)
Exemplo n.º 11
0
def test_delete_all_fail(sagemaker_boto_client):
    """When ``list_trials`` raises, ``delete_all`` raises with a retry message."""
    trial_obj = trial.Trial(sagemaker_boto_client, trial_name="foo")
    sagemaker_boto_client.list_trials.side_effect = Exception

    with pytest.raises(Exception) as excinfo:
        trial_obj.delete_all(action="--force")

    message = str(excinfo.value)
    assert message == "Failed to delete, please try again."
Exemplo n.º 12
0
def test_delete_all(sagemaker_boto_client, datetime_obj):
    """``delete_all(action="--force")`` deletes every listed component, then the trial.

    The mocked client lists four trial components. The test asserts that:
    - ``delete_trial_component`` is called once per component, in listing order;
    - ``delete_trial`` is called for the trial itself.

    ``datetime_obj`` is requested as a fixture, as the other tests in this file
    do. Previously it was referenced without being requested, which bound a
    module-level name instead of the fixture value, or raised NameError.
    """
    obj = trial.Trial(sagemaker_boto_client, trial_name="foo")

    sagemaker_boto_client.list_trials.return_value = {
        "TrialSummaries": [
            {"TrialName": name, "CreationTime": datetime_obj, "LastModifiedTime": datetime_obj}
            for name in ("trial-1", "trial-2")
        ]
    }

    component_names = [
        "trial-component-1",
        "trial-component-2",
        "trial-component-3",
        "trial-component-4",
    ]

    # A single page of summaries (no NextToken), so listing stops after one call.
    sagemaker_boto_client.list_trial_components.side_effect = [
        {
            "TrialComponentSummaries": [
                {"TrialComponentName": name, "CreationTime": datetime_obj, "LastModifiedTime": datetime_obj}
                for name in component_names
            ]
        },
    ]

    # One describe response per component, in the same order as the listing.
    sagemaker_boto_client.describe_trial_component.side_effect = [
        {"TrialComponentName": name} for name in component_names
    ]

    sagemaker_boto_client.delete_trial_component.return_value = {}
    sagemaker_boto_client.delete_trial.return_value = {}

    obj.delete_all(action="--force")

    sagemaker_boto_client.delete_trial.assert_called_with(TrialName="foo")

    delete_trial_component_expected_calls = [
        unittest.mock.call(TrialComponentName=name) for name in component_names
    ]
    assert delete_trial_component_expected_calls == sagemaker_boto_client.delete_trial_component.mock_calls
Exemplo n.º 13
0
def test_delete_all_with_incorrect_action_name(sagemaker_boto_client):
    """An unrecognised ``action`` value such as ``"abc"`` is rejected with ``ValueError``."""
    trial_obj = trial.Trial(sagemaker_boto_client, trial_name="foo")
    bad_action = "abc"
    with pytest.raises(ValueError):
        trial_obj.delete_all(action=bad_action)
def test_list_trial_components_empty(sagemaker_boto_client):
    """An empty ``TrialComponentSummaries`` response yields no results."""
    client = sagemaker_boto_client
    client.list_trial_components.return_value = {"TrialComponentSummaries": []}

    listed = list(trial.Trial(sagemaker_boto_client=client).list_trial_components())
    assert not listed