def test_remove_trial_component_from_trial_component_summary(sagemaker_boto_client):
    """Removing a TrialComponentSummary disassociates it using the summary's trial_component_name."""
    trial_obj = trial.Trial(sagemaker_boto_client)
    trial_obj.trial_name = "bar"

    summary = api_types.TrialComponentSummary()
    summary.trial_component_name = "tcs-foo"

    trial_obj.remove_trial_component(summary)

    sagemaker_boto_client.disassociate_trial_component.assert_called_with(
        TrialName="bar",
        TrialComponentName="tcs-foo",
    )
def test_remove_trial_component_from_tracker(sagemaker_boto_client):
    """Removing a Tracker disassociates the trial component the tracker wraps."""
    trial_obj = trial.Trial(sagemaker_boto_client)
    trial_obj.trial_name = "bar"

    wrapped = trial_component.TrialComponent(
        trial_component_name="tc-foo",
        sagemaker_boto_client=sagemaker_boto_client,
    )
    # The two extra Tracker constructor arguments are irrelevant here, so plain mocks stand in for them.
    component_tracker = tracker.Tracker(wrapped, unittest.mock.Mock(), unittest.mock.Mock())

    trial_obj.remove_trial_component(component_tracker)

    sagemaker_boto_client.disassociate_trial_component.assert_called_with(
        TrialName="bar",
        TrialComponentName="tc-foo",
    )
def test_list_trial_components_two_values(sagemaker_boto_client, datetime_obj):
    """A single page with two summaries yields two TrialComponentSummary objects, in order."""
    names = ["trial-component-foo-1", "trial-component-foo-2"]

    trial_obj = trial.Trial(sagemaker_boto_client=sagemaker_boto_client)
    sagemaker_boto_client.list_trial_components.return_value = {
        "TrialComponentSummaries": [
            {"TrialComponentName": name, "CreationTime": datetime_obj, "LastModifiedTime": datetime_obj}
            for name in names
        ]
    }

    actual = list(trial_obj.list_trial_components())

    expected = [
        api_types.TrialComponentSummary(
            trial_component_name=name, creation_time=datetime_obj, last_modified_time=datetime_obj
        )
        for name in names
    ]
    assert actual == expected
def test_next_token(sagemaker_boto_client, datetime_obj):
    """list_trial_components follows NextToken to a second page and yields all summaries in order."""

    def raw_summary(name):
        # Shape of one entry in the boto ListTrialComponents response used by this test.
        return {"TrialComponentName": name, "CreationTime": datetime_obj, "LastModifiedTime": datetime_obj}

    def expected_summary(name):
        return api_types.TrialComponentSummary(
            trial_component_name=name, creation_time=datetime_obj, last_modified_time=datetime_obj
        )

    trial_obj = trial.Trial(sagemaker_boto_client)

    first_page = {
        "TrialComponentSummaries": [raw_summary("trial-component-foo-1"), raw_summary("trial-component-foo-2")],
        "NextToken": "foo",
    }
    second_page = {"TrialComponentSummaries": [raw_summary("trial-component-foo-3")]}
    sagemaker_boto_client.list_trial_components.side_effect = [first_page, second_page]

    all_names = ["trial-component-foo-1", "trial-component-foo-2", "trial-component-foo-3"]
    assert list(trial_obj.list_trial_components()) == [expected_summary(name) for name in all_names]

    # The first page is requested with no arguments; the second carries the returned token.
    sagemaker_boto_client.list_trial_components.assert_any_call()
    sagemaker_boto_client.list_trial_components.assert_any_call(NextToken="foo")
def test_add_trial_component(sagemaker_boto_client):
    """add_trial_component accepts a name, a TrialComponent or a Tracker and associates by component name."""
    t = trial.Trial(sagemaker_boto_client)
    t.trial_name = "bar"

    by_object = trial_component.TrialComponent(
        trial_component_name="tc-foo", sagemaker_boto_client=sagemaker_boto_client
    )
    tracked = trial_component.TrialComponent(
        trial_component_name="tc-foo2", sagemaker_boto_client=sagemaker_boto_client
    )
    by_tracker = tracker.Tracker(tracked, unittest.mock.Mock(), unittest.mock.Mock())

    cases = [
        ("foo", "foo"),
        (by_object, by_object.trial_component_name),
        (by_tracker, tracked.trial_component_name),
    ]
    for argument, expected_name in cases:
        t.add_trial_component(argument)
        # assert_called_with only inspects the most recent call, so check after every add.
        sagemaker_boto_client.associate_trial_component.assert_called_with(
            TrialName="bar", TrialComponentName=expected_name
        )
def test_list_trial_components_call_args(sagemaker_boto_client):
    """Filter and paging arguments are forwarded to the boto call together with the trial's name."""
    call_kwargs = {
        "created_after": datetime.datetime(1990, 10, 12, 0, 0, 0),
        "created_before": datetime.datetime(1999, 10, 12, 0, 0, 0),
        "next_token": "thetoken",
        "max_results": 99,
    }

    trial_obj = trial.Trial(sagemaker_boto_client=sagemaker_boto_client)
    trial_obj.trial_name = "foo-trial"
    # An empty response dict: no summaries are yielded.
    sagemaker_boto_client.list_trial_components.return_value = {}

    assert list(trial_obj.list_trial_components(**call_kwargs)) == []

    sagemaker_boto_client.list_trial_components.assert_called_with(
        CreatedBefore=call_kwargs["created_before"],
        CreatedAfter=call_kwargs["created_after"],
        TrialName="foo-trial",
        NextToken=call_kwargs["next_token"],
        MaxResults=call_kwargs["max_results"],
    )
def test_boto_ignore(sagemaker_boto_client):
    """Trial._boto_ignore() returns exactly ["ResponseMetadata", "CreatedBy"].

    ``sagemaker_boto_client`` is requested as a fixture. Without the parameter, the name would
    resolve to a module-level object (the fixture function itself, or a NameError) instead of
    the mocked client that the other tests use.
    """
    obj = trial.Trial(sagemaker_boto_client, trial_name="foo")
    assert obj._boto_ignore() == ["ResponseMetadata", "CreatedBy"]
def test_delete(sagemaker_boto_client):
    """delete() calls delete_trial on the client with this trial's name."""
    trial_obj = trial.Trial(sagemaker_boto_client, trial_name="foo")
    client = sagemaker_boto_client
    client.delete_trial.return_value = {}

    trial_obj.delete()

    client.delete_trial.assert_called_with(TrialName="foo")
def test_remove_trial_component(sagemaker_boto_client):
    """Removing by a plain string name disassociates that name from the trial."""
    trial_obj = trial.Trial(sagemaker_boto_client)
    trial_obj.trial_name = "bar"

    trial_obj.remove_trial_component("foo")

    sagemaker_boto_client.disassociate_trial_component.assert_called_with(
        TrialName="bar",
        TrialComponentName="foo",
    )
def test_remove_trial_component_from_trial_component(sagemaker_boto_client):
    """Removing a TrialComponent object disassociates it using its trial_component_name.

    Given a distinct name, this test no longer redefines (and so silently shadows) the
    string-name ``test_remove_trial_component`` test. It also covers the TrialComponent
    argument type, which the other removal tests do not exercise.
    """
    t = trial.Trial(sagemaker_boto_client)
    t.trial_name = "bar"
    tc = trial_component.TrialComponent(trial_component_name="tc-foo", sagemaker_boto_client=sagemaker_boto_client)

    t.remove_trial_component(tc)

    sagemaker_boto_client.disassociate_trial_component.assert_called_with(
        TrialName="bar", TrialComponentName="tc-foo"
    )
def test_delete_all_fail(sagemaker_boto_client):
    """When the client's list_trials raises, delete_all raises with a generic failure message."""
    obj = trial.Trial(sagemaker_boto_client, trial_name="foo")
    sagemaker_boto_client.list_trials.side_effect = Exception

    with pytest.raises(Exception) as exc_info:
        obj.delete_all(action="--force")

    # Checked outside the context manager: code after the raising call inside it would never run.
    assert str(exc_info.value) == "Failed to delete, please try again."
def test_delete_all(sagemaker_boto_client, datetime_obj):
    """delete_all(action="--force") deletes each listed trial component, in order, then the trial.

    ``datetime_obj`` is requested as a fixture, as the other listing tests in this module do.
    Without the parameter, the name would resolve to a module-level object (the fixture
    function, or a NameError) rather than a datetime value.
    """
    obj = trial.Trial(sagemaker_boto_client, trial_name="foo")
    timestamps = {"CreationTime": datetime_obj, "LastModifiedTime": datetime_obj}
    component_names = [f"trial-component-{i}" for i in range(1, 5)]

    sagemaker_boto_client.list_trials.return_value = {
        "TrialSummaries": [
            {"TrialName": "trial-1", **timestamps},
            {"TrialName": "trial-2", **timestamps},
        ]
    }
    # Exactly one page of components. A second list call would exhaust side_effect and raise StopIteration.
    sagemaker_boto_client.list_trial_components.side_effect = [
        {"TrialComponentSummaries": [{"TrialComponentName": name, **timestamps} for name in component_names]},
    ]
    sagemaker_boto_client.describe_trial_component.side_effect = [
        {"TrialComponentName": name} for name in component_names
    ]
    sagemaker_boto_client.delete_trial_component.return_value = {}
    sagemaker_boto_client.delete_trial.return_value = {}

    obj.delete_all(action="--force")

    sagemaker_boto_client.delete_trial.assert_called_with(TrialName="foo")
    delete_trial_component_expected_calls = [
        unittest.mock.call(TrialComponentName=name) for name in component_names
    ]
    assert delete_trial_component_expected_calls == sagemaker_boto_client.delete_trial_component.mock_calls
def test_delete_all_with_incorrect_action_name(sagemaker_boto_client):
    """delete_all rejects an unrecognised action string with ValueError."""
    trial_obj = trial.Trial(sagemaker_boto_client, trial_name="foo")
    with pytest.raises(ValueError):
        trial_obj.delete_all(action="abc")
def test_list_trial_components_empty(sagemaker_boto_client):
    """A response with an empty summary list yields nothing."""
    sagemaker_boto_client.list_trial_components.return_value = {"TrialComponentSummaries": []}
    trial_obj = trial.Trial(sagemaker_boto_client=sagemaker_boto_client)

    summaries = list(trial_obj.list_trial_components())

    assert summaries == []