示例#1
0
 def __exit__(self, exc_type, exc_value, exc_traceback):
     """Record the end time, finalize the trial component status, and close.

     The end time is always stored on ``self._end_time``. When not running
     inside a SageMaker job, the trial component also receives the end time
     and a terminal status: "Failed" (with the exception text as message) if
     ``exc_value`` is truthy, otherwise "Completed".

     Returns ``None``, so an exception raised inside the ``with`` body is not
     suppressed.
     """
     self._end_time = datetime.datetime.now(dateutil.tz.tzlocal())
     if not self._in_sagemaker_job:
         component = self.trial_component
         component.end_time = self._end_time
         # Build the terminal status first, then assign it in one place.
         if exc_value:
             final_status = api_types.TrialComponentStatus(
                 primary_status="Failed", message=str(exc_value))
         else:
             final_status = api_types.TrialComponentStatus(
                 primary_status="Completed")
         component.status = final_status
     self.close()
示例#2
0
 def __enter__(self):
     """Record the start time and mark the trial component as in progress.

     The start time is always stored on ``self._start_time``. The trial
     component's ``start_time`` and ``status`` are only written when not
     running inside a SageMaker job.

     Returns:
         This object, so it can be bound by ``with ... as``.
     """
     self._start_time = datetime.datetime.now(dateutil.tz.tzlocal())
     if self._in_sagemaker_job:
         # Inside a SageMaker job the trial component is left untouched here.
         return self
     component = self.trial_component
     component.start_time = self._start_time
     component.status = api_types.TrialComponentStatus(
         primary_status="InProgress")
     return self
示例#3
0
def test_save(trial_component_obj, sagemaker_boto_client):
    """Save a modified trial component and check that loading it round-trips.

    Sets display name, status, start/end times, parameters and input/output
    artifacts on ``trial_component_obj``, saves it, loads it again by name and
    compares the loaded fields with the ones that were saved.
    """
    trial_component_obj.display_name = str(uuid.uuid4())
    trial_component_obj.status = api_types.TrialComponentStatus(
        primary_status="InProgress", message="Message")
    trial_component_obj.start_time = datetime.datetime.now(
        datetime.timezone.utc) - datetime.timedelta(days=1)
    trial_component_obj.end_time = datetime.datetime.now(datetime.timezone.utc)
    trial_component_obj.parameters = {"foo": "bar", "whizz": 100.1}
    trial_component_obj.input_artifacts = {
        "snizz":
        api_types.TrialComponentArtifact(value="s3:/foo/bar",
                                         media_type="text/plain")
    }
    trial_component_obj.output_artifacts = {
        "fly":
        api_types.TrialComponentArtifact(value="s3:/sky/far",
                                         media_type="away/tomorrow")
    }
    trial_component_obj.save()

    loaded = trial_component.TrialComponent.load(
        trial_component_name=trial_component_obj.trial_component_name,
        sagemaker_boto_client=sagemaker_boto_client)

    assert trial_component_obj.trial_component_name == loaded.trial_component_name
    assert trial_component_obj.status == loaded.status

    # Compare timestamps with a one-second tolerance. abs() is required: a
    # bare subtraction is negative whenever the loaded time is later than the
    # saved one, and a negative timedelta is always < 1 second, so the check
    # would pass no matter how far off the loaded value is.
    tolerance = datetime.timedelta(seconds=1)
    assert abs(trial_component_obj.start_time - loaded.start_time) < tolerance
    assert abs(trial_component_obj.end_time - loaded.end_time) < tolerance

    assert trial_component_obj.parameters == loaded.parameters
    assert trial_component_obj.input_artifacts == loaded.input_artifacts
    assert trial_component_obj.output_artifacts == loaded.output_artifacts
def test_save(trial_component_obj, sagemaker_boto_client):
    """Save a modified trial component and check that loading it round-trips.

    Sets display name, status, start/end times, parameters and input/output
    artifacts on ``trial_component_obj``, saves it, loads it again by name and
    compares the loaded fields with the ones that were saved.
    """
    trial_component_obj.display_name = str(uuid.uuid4())
    trial_component_obj.status = api_types.TrialComponentStatus(primary_status='InProgress', message='Message')
    trial_component_obj.start_time = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(days=1)
    trial_component_obj.end_time = datetime.datetime.now(datetime.timezone.utc)
    trial_component_obj.parameters = {
        'foo': 'bar',
        'whizz': 100.1
    }
    trial_component_obj.input_artifacts = {
        'snizz': api_types.TrialComponentArtifact(value='s3:/foo/bar', media_type='text/plain')
    }
    trial_component_obj.output_artifacts = {
        'fly': api_types.TrialComponentArtifact(value='s3:/sky/far', media_type='away/tomorrow')
    }
    trial_component_obj.save()

    loaded = trial_component.TrialComponent.load(trial_component_name=trial_component_obj.trial_component_name,
                                                 sagemaker_boto_client=sagemaker_boto_client)

    assert trial_component_obj.trial_component_name == loaded.trial_component_name
    assert trial_component_obj.status == loaded.status

    # Compare timestamps with a one-second tolerance. abs() is required: a
    # bare subtraction is negative whenever the loaded time is later than the
    # saved one, and a negative timedelta is always < 1 second, so the check
    # would pass no matter how far off the loaded value is.
    tolerance = datetime.timedelta(seconds=1)
    assert abs(trial_component_obj.start_time - loaded.start_time) < tolerance
    assert abs(trial_component_obj.end_time - loaded.end_time) < tolerance

    assert trial_component_obj.parameters == loaded.parameters
    assert trial_component_obj.input_artifacts == loaded.input_artifacts
    assert trial_component_obj.output_artifacts == loaded.output_artifacts
示例#5
0
def test_load(sagemaker_boto_client):
    """Check that ``TrialComponent.load`` maps a describe response onto the object.

    ``describe_trial_component`` on the client is stubbed with a canned
    response; the test then verifies the call arguments and that each field
    of the response is exposed on the loaded object.
    """
    now = datetime.datetime.now(datetime.timezone.utc)

    sagemaker_boto_client.describe_trial_component.return_value = {
        "TrialComponentArn": "A",
        "TrialComponentName": "B",
        "DisplayName": "C",
        "Status": {"PrimaryStatus": "InProgress", "Message": "D"},
        "Parameters": {"E": {"NumberValue": 1.0}, "F": {"StringValue": "G"}},
        "InputArtifacts": {"H": {"Value": "s3://foo/bar", "MediaType": "text/plain"}},
        "OutputArtifacts": {"I": {"Value": "s3://whizz/bang", "MediaType": "text/plain"}},
        "Metrics": [
            {
                "MetricName": "J",
                "Count": 1,
                "Min": 1.0,
                "Max": 2.0,
                "Avg": 3.0,
                "StdDev": 4.0,
                "SourceArn": "K",
                "Timestamp": now,
            }
        ],
    }
    obj = trial_component.TrialComponent.load(trial_component_name="foo", sagemaker_boto_client=sagemaker_boto_client)
    sagemaker_boto_client.describe_trial_component.assert_called_with(TrialComponentName="foo")
    assert "A" == obj.trial_component_arn
    assert "B" == obj.trial_component_name
    assert "C" == obj.display_name
    assert api_types.TrialComponentStatus(primary_status="InProgress", message="D") == obj.status
    assert {"E": 1.0, "F": "G"} == obj.parameters
    # The expected values below must be compared against the loaded object:
    # asserting a non-empty dict/list literal on its own is always true and
    # verifies nothing.
    assert {"H": api_types.TrialComponentArtifact(value="s3://foo/bar", media_type="text/plain")} == obj.input_artifacts
    assert {
        "I": api_types.TrialComponentArtifact(value="s3://whizz/bang", media_type="text/plain")
    } == obj.output_artifacts
    # NOTE(review): the attribute name ``metrics`` is assumed from the
    # snake_case mapping used for the other response keys — confirm against
    # TrialComponent.
    assert [
        api_types.TrialComponentMetricSummary(
            metric_name="J", count=1, min=1.0, max=2.0, avg=3.0, std_dev=4.0, source_arn="K", timestamp=now
        )
    ] == obj.metrics
示例#6
0
def test_list(sagemaker_boto_client):
    """Check that ``TrialComponent.list`` follows pagination and converts summaries.

    ``list_trial_components`` is stubbed to return two pages of ten summaries
    each, the first carrying ``NextToken`` "100". The test verifies that the
    twenty summaries come back as ``TrialComponentSummary`` objects in order,
    and that the second client call passes the token along with the filters.
    """
    start_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=1)
    end_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=2)
    creation_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=3)
    last_modified_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=4)

    def boto_summary(index):
        # One raw summary dict as the stubbed client returns it.
        tag = str(index)
        shift = datetime.timedelta(hours=index)
        return {
            "TrialComponentName": "A" + tag,
            "TrialComponentArn": "B" + tag,
            "DisplayName": "C" + tag,
            "SourceArn": "D" + tag,
            "Status": {"PrimaryStatus": "InProgress", "Message": "E" + tag},
            "StartTime": start_time + shift,
            "EndTime": end_time + shift,
            "CreationTime": creation_time + shift,
            "LastModifiedTime": last_modified_time + shift,
            "LastModifiedBy": {},
        }

    def expected_summary(index):
        # The object the listing is expected to yield for that raw summary.
        tag = str(index)
        shift = datetime.timedelta(hours=index)
        return api_types.TrialComponentSummary(
            trial_component_name="A" + tag,
            trial_component_arn="B" + tag,
            display_name="C" + tag,
            source_arn="D" + tag,
            status=api_types.TrialComponentStatus(primary_status="InProgress", message="E" + tag),
            start_time=start_time + shift,
            end_time=end_time + shift,
            creation_time=creation_time + shift,
            last_modified_time=last_modified_time + shift,
            last_modified_by={},
        )

    first_page = {
        "TrialComponentSummaries": [boto_summary(i) for i in range(10)],
        "NextToken": "100",
    }
    # The last page has no NextToken.
    last_page = {"TrialComponentSummaries": [boto_summary(i) for i in range(10, 20)]}
    sagemaker_boto_client.list_trial_components.side_effect = [first_page, last_page]

    expected = [expected_summary(i) for i in range(20)]
    listing = trial_component.TrialComponent.list(
        sagemaker_boto_client=sagemaker_boto_client,
        source_arn="foo",
        sort_by="CreationTime",
        sort_order="Ascending",
    )
    result = list(listing)

    assert expected == result
    filters = {"SortBy": "CreationTime", "SortOrder": "Ascending", "SourceArn": "foo"}
    expected_calls = [
        unittest.mock.call(**filters),
        unittest.mock.call(NextToken="100", **filters),
    ]
    assert expected_calls == sagemaker_boto_client.list_trial_components.mock_calls
def test_load(sagemaker_boto_client):
    """Check that ``TrialComponent.load`` maps a describe response onto the object.

    ``describe_trial_component`` on the client is stubbed with a canned
    response; the test then verifies the call arguments and that each field
    of the response is exposed on the loaded object.
    """
    now = datetime.datetime.now(datetime.timezone.utc)

    sagemaker_boto_client.describe_trial_component.return_value = {
        'TrialComponentArn':
        'A',
        'TrialComponentName':
        'B',
        'DisplayName':
        'C',
        'Status': {
            'PrimaryStatus': 'InProgress',
            'Message': 'D'
        },
        'Parameters': {
            'E': {
                'NumberValue': 1.0
            },
            'F': {
                'StringValue': 'G'
            }
        },
        'InputArtifacts': {
            'H': {
                'Value': 's3://foo/bar',
                'MediaType': 'text/plain'
            }
        },
        'OutputArtifacts': {
            'I': {
                'Value': 's3://whizz/bang',
                'MediaType': 'text/plain'
            }
        },
        'Metrics': [{
            'MetricName': 'J',
            'Count': 1,
            'Min': 1.0,
            'Max': 2.0,
            'Avg': 3.0,
            'StdDev': 4.0,
            'SourceArn': 'K',
            'Timestamp': now
        }]
    }
    obj = trial_component.TrialComponent.load(
        trial_component_name='foo',
        sagemaker_boto_client=sagemaker_boto_client)
    sagemaker_boto_client.describe_trial_component.assert_called_with(
        TrialComponentName='foo')
    assert 'A' == obj.trial_component_arn
    assert 'B' == obj.trial_component_name
    assert 'C' == obj.display_name
    assert api_types.TrialComponentStatus(primary_status='InProgress',
                                          message='D') == obj.status
    assert {'E': 1.0, 'F': 'G'} == obj.parameters
    # The expected values below must be compared against the loaded object:
    # asserting a non-empty dict/list literal on its own is always true and
    # verifies nothing.
    assert {
        'H':
        api_types.TrialComponentArtifact(value='s3://foo/bar',
                                         media_type='text/plain')
    } == obj.input_artifacts
    assert {
        'I':
        api_types.TrialComponentArtifact(value='s3://whizz/bang',
                                         media_type='text/plain')
    } == obj.output_artifacts
    # NOTE(review): the attribute name ``metrics`` is assumed from the
    # snake_case mapping used for the other response keys — confirm against
    # TrialComponent.
    assert [
        api_types.TrialComponentMetricSummary(metric_name='J',
                                              count=1,
                                              min=1.0,
                                              max=2.0,
                                              avg=3.0,
                                              std_dev=4.0,
                                              source_arn='K',
                                              timestamp=now)
    ] == obj.metrics
def test_list(sagemaker_boto_client):
    """Check that ``TrialComponent.list`` follows pagination and converts summaries.

    ``list_trial_components`` is stubbed to return two pages of ten summaries
    each, the first carrying ``NextToken`` '100'. The test verifies that the
    twenty summaries come back as ``TrialComponentSummary`` objects in order,
    and that the second client call passes the token along with the filters.
    """
    start_time = datetime.datetime.now(
        datetime.timezone.utc) + datetime.timedelta(hours=1)
    end_time = datetime.datetime.now(
        datetime.timezone.utc) + datetime.timedelta(hours=2)
    creation_time = datetime.datetime.now(
        datetime.timezone.utc) + datetime.timedelta(hours=3)
    last_modified_time = datetime.datetime.now(
        datetime.timezone.utc) + datetime.timedelta(hours=4)

    def boto_summary(index):
        # One raw summary dict as the stubbed client returns it.
        tag = str(index)
        shift = datetime.timedelta(hours=index)
        return {
            'TrialComponentName': 'A' + tag,
            'TrialComponentArn': 'B' + tag,
            'DisplayName': 'C' + tag,
            'SourceArn': 'D' + tag,
            'Status': {
                'PrimaryStatus': 'InProgress',
                'Message': 'E' + tag
            },
            'StartTime': start_time + shift,
            'EndTime': end_time + shift,
            'CreationTime': creation_time + shift,
            'LastModifiedTime': last_modified_time + shift,
            'LastModifiedBy': {}
        }

    def expected_summary(index):
        # The object the listing is expected to yield for that raw summary.
        tag = str(index)
        shift = datetime.timedelta(hours=index)
        return api_types.TrialComponentSummary(
            trial_component_name='A' + tag,
            trial_component_arn='B' + tag,
            display_name='C' + tag,
            source_arn='D' + tag,
            status=api_types.TrialComponentStatus(primary_status='InProgress',
                                                  message='E' + tag),
            start_time=start_time + shift,
            end_time=end_time + shift,
            creation_time=creation_time + shift,
            last_modified_time=last_modified_time + shift,
            last_modified_by={})

    first_page = {
        'TrialComponentSummaries': [boto_summary(i) for i in range(10)],
        'NextToken': '100'
    }
    # The last page has no NextToken.
    last_page = {
        'TrialComponentSummaries': [boto_summary(i) for i in range(10, 20)]
    }
    sagemaker_boto_client.list_trial_components.side_effect = [
        first_page, last_page
    ]

    expected = [expected_summary(i) for i in range(20)]
    listing = trial_component.TrialComponent.list(
        sagemaker_boto_client=sagemaker_boto_client,
        source_arn='foo',
        sort_by='CreationTime',
        sort_order='Ascending')
    result = list(listing)
    assert expected == result
    filters = {
        'SortBy': 'CreationTime',
        'SortOrder': 'Ascending',
        'SourceArn': 'foo'
    }
    expected_calls = [
        unittest.mock.call(**filters),
        unittest.mock.call(NextToken='100', **filters)
    ]
    assert expected_calls == sagemaker_boto_client.list_trial_components.mock_calls