def test_list_trial_components_two_values(sagemaker_boto_client, datetime_obj):
    """A single page with two summaries yields two TrialComponentSummary objects.

    The mocked ``list_trial_components`` returns one response (no NextToken)
    holding two raw summaries; both must come back converted, in order.
    """
    component_names = ["trial-component-foo-1", "trial-component-foo-2"]

    subject = trial.Trial(sagemaker_boto_client=sagemaker_boto_client)
    sagemaker_boto_client.list_trial_components.return_value = {
        "TrialComponentSummaries": [
            {
                "TrialComponentName": name,
                "CreationTime": datetime_obj,
                "LastModifiedTime": datetime_obj,
            }
            for name in component_names
        ]
    }

    actual = list(subject.list_trial_components())
    wanted = [
        api_types.TrialComponentSummary(
            trial_component_name=name,
            creation_time=datetime_obj,
            last_modified_time=datetime_obj,
        )
        for name in component_names
    ]
    assert actual == wanted
def test_remove_trial_component_from_trial_component_summary(sagemaker_boto_client):
    """Removing a TrialComponentSummary disassociates it by its component name.

    The trial is named ``"bar"`` and the summary ``"tcs-foo"``; the client's
    ``disassociate_trial_component`` must be called with both names.
    """
    summary = api_types.TrialComponentSummary()
    summary.trial_component_name = "tcs-foo"

    owner = trial.Trial(sagemaker_boto_client)
    owner.trial_name = "bar"
    owner.remove_trial_component(summary)

    disassociate = sagemaker_boto_client.disassociate_trial_component
    disassociate.assert_called_with(TrialName="bar", TrialComponentName="tcs-foo")
def test_next_token(sagemaker_boto_client, datetime_obj):
    """Trial.list_trial_components follows NextToken across pages.

    The mocked client returns two pages: the first carries ``NextToken="foo"``
    and the second carries no token. All three summaries must be yielded in
    order. The first request must be sent with no arguments, and the second
    (and last) request must resume from ``"foo"``.
    """
    trial_obj = trial.Trial(sagemaker_boto_client)
    sagemaker_boto_client.list_trial_components.side_effect = [
        {
            "TrialComponentSummaries": [
                {
                    "TrialComponentName": "trial-component-foo-1",
                    "CreationTime": datetime_obj,
                    "LastModifiedTime": datetime_obj,
                },
                {
                    "TrialComponentName": "trial-component-foo-2",
                    "CreationTime": datetime_obj,
                    "LastModifiedTime": datetime_obj,
                },
            ],
            "NextToken": "foo",
        },
        {
            "TrialComponentSummaries": [
                {
                    "TrialComponentName": "trial-component-foo-3",
                    "CreationTime": datetime_obj,
                    "LastModifiedTime": datetime_obj,
                }
            ]
        },
    ]
    assert list(trial_obj.list_trial_components()) == [
        api_types.TrialComponentSummary(
            trial_component_name="trial-component-foo-1",
            creation_time=datetime_obj,
            last_modified_time=datetime_obj,
        ),
        api_types.TrialComponentSummary(
            trial_component_name="trial-component-foo-2",
            creation_time=datetime_obj,
            last_modified_time=datetime_obj,
        ),
        api_types.TrialComponentSummary(
            trial_component_name="trial-component-foo-3",
            creation_time=datetime_obj,
            last_modified_time=datetime_obj,
        ),
    ]

    list_mock = sagemaker_boto_client.list_trial_components
    # One request per page. side_effect only holds two responses, so an extra
    # request would raise StopIteration, which generator-based paging code
    # might swallow; pin the count explicitly instead of relying on that.
    assert list_mock.call_count == 2
    # The first page is requested with no arguments at all.
    list_mock.assert_any_call()
    # The final request must resume from the token returned by the first page.
    list_mock.assert_called_with(NextToken="foo")
def test_list(sagemaker_boto_client):
    """TrialComponent.list pages through NextToken and converts every summary.

    Two pages of ten raw summaries each are served by the mocked
    ``list_trial_components``. The result must contain all twenty converted
    ``TrialComponentSummary`` objects in order. The second request must repeat
    the filter/sort arguments and add the ``NextToken`` from the first page.

    NOTE(review): this file defines ``test_list`` more than once. The later
    definition rebinds the module-level name, so pytest only collects the last
    one. Consider giving each a distinct name.
    """

    def hours_from_now(hours):
        return datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=hours)

    start_time = hours_from_now(1)
    end_time = hours_from_now(2)
    creation_time = hours_from_now(3)
    last_modified_time = hours_from_now(4)

    def raw_summary(index):
        # Boto-shaped dict as the mocked client would return it.
        shift = datetime.timedelta(hours=index)
        suffix = str(index)
        return {
            "TrialComponentName": "A" + suffix,
            "TrialComponentArn": "B" + suffix,
            "DisplayName": "C" + suffix,
            "SourceArn": "D" + suffix,
            "Status": {"PrimaryStatus": "InProgress", "Message": "E" + suffix},
            "StartTime": start_time + shift,
            "EndTime": end_time + shift,
            "CreationTime": creation_time + shift,
            "LastModifiedTime": last_modified_time + shift,
            "LastModifiedBy": {},
        }

    def converted_summary(index):
        # The TrialComponentSummary that raw_summary(index) should become.
        shift = datetime.timedelta(hours=index)
        suffix = str(index)
        return api_types.TrialComponentSummary(
            trial_component_name="A" + suffix,
            trial_component_arn="B" + suffix,
            display_name="C" + suffix,
            source_arn="D" + suffix,
            status=api_types.TrialComponentStatus(primary_status="InProgress", message="E" + suffix),
            start_time=start_time + shift,
            end_time=end_time + shift,
            creation_time=creation_time + shift,
            last_modified_time=last_modified_time + shift,
            last_modified_by={},
        )

    first_page = {
        "TrialComponentSummaries": [raw_summary(k) for k in range(10)],
        "NextToken": "100",
    }
    last_page = {"TrialComponentSummaries": [raw_summary(k) for k in range(10, 20)]}
    sagemaker_boto_client.list_trial_components.side_effect = [first_page, last_page]

    expected = [converted_summary(k) for k in range(20)]
    result = list(
        trial_component.TrialComponent.list(
            sagemaker_boto_client=sagemaker_boto_client,
            source_arn="foo",
            sort_by="CreationTime",
            sort_order="Ascending",
        )
    )
    assert expected == result

    expected_calls = [
        unittest.mock.call(SortBy="CreationTime", SortOrder="Ascending", SourceArn="foo"),
        unittest.mock.call(NextToken="100", SortBy="CreationTime", SortOrder="Ascending", SourceArn="foo"),
    ]
    assert expected_calls == sagemaker_boto_client.list_trial_components.mock_calls
def test_list(sagemaker_boto_client):
    """TrialComponent.list follows NextToken and yields all twenty summaries.

    Twenty raw summaries are split across two mocked responses; the first
    response carries ``NextToken='100'``. The converted results must match
    in order. The requests must be exactly the filtered/sorted call, followed
    by the same call plus the token.

    NOTE(review): this file defines ``test_list`` more than once. Because the
    module-level name is rebound, only the last definition is collected by
    pytest; any earlier one never runs. Consider renaming.
    """
    # Four independent "now + h hours" timestamps, evaluated in order.
    start_time, end_time, creation_time, last_modified_time = (
        datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=h)
        for h in (1, 2, 3, 4))

    raw_summaries = []
    expected = []
    for n in range(20):
        delta = datetime.timedelta(hours=n)
        tag = str(n)
        raw_summaries.append({
            'TrialComponentName': 'A' + tag,
            'TrialComponentArn': 'B' + tag,
            'DisplayName': 'C' + tag,
            'SourceArn': 'D' + tag,
            'Status': {
                'PrimaryStatus': 'InProgress',
                'Message': 'E' + tag
            },
            'StartTime': start_time + delta,
            'EndTime': end_time + delta,
            'CreationTime': creation_time + delta,
            'LastModifiedTime': last_modified_time + delta,
            'LastModifiedBy': {}
        })
        expected.append(
            api_types.TrialComponentSummary(
                trial_component_name='A' + tag,
                trial_component_arn='B' + tag,
                display_name='C' + tag,
                source_arn='D' + tag,
                status=api_types.TrialComponentStatus(primary_status='InProgress',
                                                      message='E' + tag),
                start_time=start_time + delta,
                end_time=end_time + delta,
                creation_time=creation_time + delta,
                last_modified_time=last_modified_time + delta,
                last_modified_by={}))

    # First ten on a page that points to the next; the rest on a final page.
    sagemaker_boto_client.list_trial_components.side_effect = [
        {
            'TrialComponentSummaries': raw_summaries[:10],
            'NextToken': '100'
        },
        {
            'TrialComponentSummaries': raw_summaries[10:]
        },
    ]

    result = list(
        trial_component.TrialComponent.list(sagemaker_boto_client=sagemaker_boto_client,
                                            source_arn='foo',
                                            sort_by='CreationTime',
                                            sort_order='Ascending'))
    assert expected == result

    request_kwargs = {'SortBy': 'CreationTime', 'SortOrder': 'Ascending', 'SourceArn': 'foo'}
    expected_calls = [
        unittest.mock.call(**request_kwargs),
        unittest.mock.call(NextToken='100', **request_kwargs),
    ]
    assert expected_calls == sagemaker_boto_client.list_trial_components.mock_calls