# Example #1
# 0
def test_remove_trial_component_from_tracker(sagemaker_boto_client):
    """Check that Trial.remove_trial_component accepts a Tracker.

    The client's disassociate_trial_component must be called with the trial's
    name and the name of the TrialComponent wrapped by the tracker.
    """
    parent_trial = trial.Trial(sagemaker_boto_client)
    parent_trial.trial_name = "bar"

    wrapped_component = trial_component.TrialComponent(
        trial_component_name="tc-foo",
        sagemaker_boto_client=sagemaker_boto_client,
    )
    component_tracker = tracker.Tracker(
        wrapped_component, unittest.mock.Mock(), unittest.mock.Mock()
    )

    parent_trial.remove_trial_component(component_tracker)

    sagemaker_boto_client.disassociate_trial_component.assert_called_with(
        TrialName="bar",
        TrialComponentName="tc-foo",
    )
def test_delete_with_force_disassociate(sagemaker_boto_client):
    """Check delete(force_disassociate=True) against a two-page trial listing.

    The component must be disassociated from every trial returned across both
    list_trials pages, in order, and must also be deleted by name.
    """
    component = trial_component.TrialComponent(
        sagemaker_boto_client,
        trial_component_name="foo",
        display_name="bar",
    )
    sagemaker_boto_client.delete_trial_component.return_value = {}

    # The first page carries a NextToken; the assertions below expect the
    # trials from the second page to be disassociated as well.
    first_page = {
        "TrialSummaries": [{"TrialName": "trial-1"}, {"TrialName": "trial-2"}],
        "NextToken": "a",
    }
    second_page = {
        "TrialSummaries": [{"TrialName": "trial-3"}, {"TrialName": "trial-4"}],
    }
    sagemaker_boto_client.list_trials.side_effect = [first_page, second_page]

    component.delete(force_disassociate=True)

    disassociated_trials = ["trial-1", "trial-2", "trial-3", "trial-4"]
    expected_calls = [
        unittest.mock.call(TrialName=trial_name, TrialComponentName="foo")
        for trial_name in disassociated_trials
    ]
    # Exact list equality: no extra, missing, or reordered disassociations.
    assert expected_calls == sagemaker_boto_client.disassociate_trial_component.mock_calls
    sagemaker_boto_client.delete_trial_component.assert_called_with(TrialComponentName="foo")
def test_delete(sagemaker_boto_client):
    """Check that delete() calls delete_trial_component with the component name."""
    component = trial_component.TrialComponent(
        sagemaker_boto_client,
        trial_component_name="foo",
        display_name="bar",
    )
    delete_api = sagemaker_boto_client.delete_trial_component
    delete_api.return_value = {}

    component.delete()

    delete_api.assert_called_with(TrialComponentName="foo")
def test_save(sagemaker_boto_client):
    """Check that save() sends the component's name and display name to update_trial_component."""
    update_api = sagemaker_boto_client.update_trial_component
    component = trial_component.TrialComponent(
        sagemaker_boto_client,
        trial_component_name='foo',
        display_name='bar',
    )
    update_api.return_value = {}

    component.save()

    update_api.assert_called_with(TrialComponentName='foo', DisplayName='bar')
# Example #5
# 0
def test_add_trial_component(sagemaker_boto_client):
    """Check that Trial.add_trial_component accepts three kinds of argument.

    The accepted kinds are a plain component name, a TrialComponent, and a
    Tracker wrapping a TrialComponent. In each case the client's
    associate_trial_component must receive the matching component name.
    """
    parent = trial.Trial(sagemaker_boto_client)
    parent.trial_name = "bar"
    associate_api = sagemaker_boto_client.associate_trial_component

    # 1) Component given by name.
    parent.add_trial_component("foo")
    associate_api.assert_called_with(TrialName="bar", TrialComponentName="foo")

    # 2) Component given as a TrialComponent object.
    component = trial_component.TrialComponent(
        trial_component_name="tc-foo", sagemaker_boto_client=sagemaker_boto_client
    )
    parent.add_trial_component(component)
    associate_api.assert_called_with(
        TrialName="bar", TrialComponentName=component.trial_component_name
    )

    # 3) Component given through a Tracker.
    tracked_component = trial_component.TrialComponent(
        trial_component_name="tc-foo2", sagemaker_boto_client=sagemaker_boto_client
    )
    component_tracker = tracker.Tracker(
        tracked_component, unittest.mock.Mock(), unittest.mock.Mock()
    )
    parent.add_trial_component(component_tracker)
    associate_api.assert_called_with(
        TrialName="bar", TrialComponentName=tracked_component.trial_component_name
    )
def test_load_in_sagemaker_processing_job(mocked_tce, sagemaker_boto_client):
    """Check Tracker.load when the environment reports a SageMaker processing job.

    The loaded tracker must be flagged as in-job, must have no metrics writer,
    and must use the trial component supplied by the environment.
    """
    # NOTE(review): mocked_tce presumably patches the trial-component
    # environment class that Tracker.load consults; confirm against the fixture.
    job_component = trial_component.TrialComponent(
        trial_component_name="foo-bar", sagemaker_boto_client=sagemaker_boto_client
    )

    fake_environment = unittest.mock.Mock(
        source_arn="arn:1234",
        environment_type=_environment.EnvironmentType.SageMakerProcessingJob,
    )
    fake_environment.get_trial_component.return_value = job_component
    mocked_tce.load.return_value = fake_environment

    loaded_tracker = tracker.Tracker.load(sagemaker_boto_client=sagemaker_boto_client)

    assert loaded_tracker._in_sagemaker_job
    assert loaded_tracker._metrics_writer is None
    assert loaded_tracker.trial_component == job_component
def test_save(sagemaker_boto_client):
    """Check that save() forwards the name, the display name and every *_to_remove list.

    The values must reach update_trial_component unchanged.
    """
    component = trial_component.TrialComponent(
        sagemaker_boto_client,
        trial_component_name="foo",
        display_name="bar",
        parameters_to_remove=["E"],
        input_artifacts_to_remove=["F"],
        output_artifacts_to_remove=["G"],
    )
    update_api = sagemaker_boto_client.update_trial_component
    update_api.return_value = {}

    component.save()

    expected_request = dict(
        TrialComponentName="foo",
        DisplayName="bar",
        ParametersToRemove=["E"],
        InputArtifactsToRemove=["F"],
        OutputArtifactsToRemove=["G"],
    )
    update_api.assert_called_with(**expected_request)
# Example #8
# 0
def test_create_with_trial_components(sagemaker_boto_client):
    """Check that Trial.create(trial_components=[...]) creates the trial and associates the components."""
    create_api = sagemaker_boto_client.create_trial
    create_api.return_value = {"Arn": "arn:aws:1234", "TrialName": "name-value"}
    component = trial_component.TrialComponent(
        trial_component_name="tc-foo", sagemaker_boto_client=sagemaker_boto_client
    )

    created = trial.Trial.create(
        trial_name="name-value",
        experiment_name="experiment-name-value",
        trial_components=[component],
        sagemaker_boto_client=sagemaker_boto_client,
    )

    assert created.trial_name == "name-value"
    create_api.assert_called_with(
        TrialName="name-value",
        ExperimentName="experiment-name-value",
    )
    sagemaker_boto_client.associate_trial_component.assert_called_with(
        TrialName="name-value",
        TrialComponentName=component.trial_component_name,
    )
# Example #9
# 0
def test_boto_ignore(sagemaker_boto_client):
    """Check that TrialComponent._boto_ignore() returns ["ResponseMetadata", "CreatedBy"].

    Args:
        sagemaker_boto_client: the pytest fixture of the same name.
    """
    # The client is requested as a fixture argument. The previous signature
    # took no arguments, so the name in the body never came from pytest: it
    # resolved to whatever module-level object shared the name (e.g. the
    # fixture function itself) or raised NameError.
    obj = trial_component.TrialComponent(
        sagemaker_boto_client, trial_component_name="foo", display_name="bar"
    )
    assert obj._boto_ignore() == ["ResponseMetadata", "CreatedBy"]
def trial_component_obj(sagemaker_boto_client):
    """Return a TrialComponent built only from the given client.

    No decorator is visible here; this is presumably a pytest fixture. Confirm
    before relying on that.
    """
    component = trial_component.TrialComponent(sagemaker_boto_client)
    return component