def test_list_trials_with_trial_component_name(sagemaker_boto_client, datetime_obj):
    """Trial.list filters by trial component and converts every returned summary.

    The stubbed client returns two trial summaries in a single page; the test checks
    that ``trial_component_name`` is forwarded as ``TrialComponentName`` and that each
    summary comes back as an equivalent ``TrialSummary``.
    """
    names = ("trial-1", "trial-2")
    sagemaker_boto_client.list_trials.return_value = {
        "TrialSummaries": [
            {"TrialName": name, "CreationTime": datetime_obj, "LastModifiedTime": datetime_obj}
            for name in names
        ]
    }
    wanted = [
        api_types.TrialSummary(trial_name=name, creation_time=datetime_obj, last_modified_time=datetime_obj)
        for name in names
    ]

    listed = list(trial.Trial.list(trial_component_name="tc-foo", sagemaker_boto_client=sagemaker_boto_client))

    assert wanted == listed
    sagemaker_boto_client.list_trials.assert_called_with(TrialComponentName="tc-foo")
def test_list_trials_two_values(sagemaker_boto_client, datetime_obj):
    """Experiment.list_trials yields one TrialSummary per entry of a single-page response.

    The stub uses the ``TrialName`` key (the field ListTrials actually returns) so the
    test exercises the real ``TrialName`` -> ``trial_name`` mapping. Because the response
    has no ``NextToken``, exactly one argument-less request is expected.
    """
    experiment_obj = experiment.Experiment(sagemaker_boto_client=sagemaker_boto_client)
    sagemaker_boto_client.list_trials.return_value = {
        "TrialSummaries": [
            {"TrialName": "trial-foo-1", "CreationTime": datetime_obj, "LastModifiedTime": datetime_obj},
            {"TrialName": "trial-foo-2", "CreationTime": datetime_obj, "LastModifiedTime": datetime_obj},
        ]
    }

    assert list(experiment_obj.list_trials()) == [
        api_types.TrialSummary(trial_name="trial-foo-1", creation_time=datetime_obj, last_modified_time=datetime_obj),
        api_types.TrialSummary(trial_name="trial-foo-2", creation_time=datetime_obj, last_modified_time=datetime_obj),
    ]
    # No NextToken in the response, so the single page must be fetched with one bare request.
    sagemaker_boto_client.list_trials.assert_called_once_with()
def test_next_token(sagemaker_boto_client, datetime_obj):
    """Experiment.list_trials follows NextToken across pages and yields summaries in order.

    The stubbed client returns two pages: the first carries ``NextToken="foo"``, and the
    second has no token. The test checks the concatenated summaries and that exactly two
    requests were made: first with no arguments, then with the token from page one.
    """
    from unittest.mock import call

    experiment_obj = experiment.Experiment(sagemaker_boto_client=sagemaker_boto_client)
    sagemaker_boto_client.list_trials.side_effect = [
        {
            "TrialSummaries": [
                {"TrialName": "trial-foo-1", "CreationTime": datetime_obj, "LastModifiedTime": datetime_obj},
                {"TrialName": "trial-foo-2", "CreationTime": datetime_obj, "LastModifiedTime": datetime_obj},
            ],
            "NextToken": "foo",
        },
        {
            "TrialSummaries": [
                {"TrialName": "trial-foo-3", "CreationTime": datetime_obj, "LastModifiedTime": datetime_obj},
            ]
        },
    ]

    assert list(experiment_obj.list_trials()) == [
        api_types.TrialSummary(trial_name="trial-foo-1", creation_time=datetime_obj, last_modified_time=datetime_obj),
        api_types.TrialSummary(trial_name="trial-foo-2", creation_time=datetime_obj, last_modified_time=datetime_obj),
        api_types.TrialSummary(trial_name="trial-foo-3", creation_time=datetime_obj, last_modified_time=datetime_obj),
    ]
    # Pin both the number and the order of page requests: the token from page one
    # must be sent on the second (and final) request.
    assert sagemaker_boto_client.list_trials.call_args_list == [call(), call(NextToken="foo")]