def test_remove_trial_component_from_tracker(sagemaker_boto_client):
    """Check that passing a Tracker to ``remove_trial_component`` disassociates its component.

    The Tracker wraps a TrialComponent named ``"tc-foo"``; the trial named
    ``"bar"`` is expected to call ``disassociate_trial_component`` with both names.
    """
    parent_trial = trial.Trial(sagemaker_boto_client)
    parent_trial.trial_name = "bar"

    component = trial_component.TrialComponent(
        trial_component_name="tc-foo",
        sagemaker_boto_client=sagemaker_boto_client,
    )
    # The remaining Tracker constructor arguments are not under test; mocks stand in.
    wrapping_tracker = tracker.Tracker(
        component,
        unittest.mock.Mock(),
        unittest.mock.Mock(),
    )

    parent_trial.remove_trial_component(wrapping_tracker)

    disassociate = sagemaker_boto_client.disassociate_trial_component
    disassociate.assert_called_with(TrialName="bar", TrialComponentName="tc-foo")
def test_delete_with_force_disassociate(sagemaker_boto_client):
    """Check that ``delete(force_disassociate=True)`` disassociates every listed trial.

    ``list_trials`` is stubbed with two pages of results (the first carries a
    ``NextToken``). The test expects one ``disassociate_trial_component`` call per
    trial, in page order, and a ``delete_trial_component`` call for the component.
    """
    component = trial_component.TrialComponent(
        sagemaker_boto_client,
        trial_component_name="foo",
        display_name="bar",
    )
    sagemaker_boto_client.delete_trial_component.return_value = {}

    first_page = {
        "TrialSummaries": [{"TrialName": "trial-1"}, {"TrialName": "trial-2"}],
        "NextToken": "a",
    }
    # No NextToken on the second page.
    second_page = {
        "TrialSummaries": [{"TrialName": "trial-3"}, {"TrialName": "trial-4"}],
    }
    sagemaker_boto_client.list_trials.side_effect = [first_page, second_page]

    component.delete(force_disassociate=True)

    listed_trials = ("trial-1", "trial-2", "trial-3", "trial-4")
    expected_calls = [
        unittest.mock.call(TrialName=name, TrialComponentName="foo")
        for name in listed_trials
    ]
    assert expected_calls == sagemaker_boto_client.disassociate_trial_component.mock_calls
    sagemaker_boto_client.delete_trial_component.assert_called_with(TrialComponentName="foo")
def test_delete(sagemaker_boto_client):
    """Check that ``delete()`` calls ``delete_trial_component`` with the component's name."""
    delete_api = sagemaker_boto_client.delete_trial_component
    delete_api.return_value = {}

    component = trial_component.TrialComponent(
        sagemaker_boto_client,
        trial_component_name="foo",
        display_name="bar",
    )
    component.delete()

    delete_api.assert_called_with(TrialComponentName="foo")
def test_save(sagemaker_boto_client):
    """Check that ``save()`` sends the component name and display name to ``update_trial_component``."""
    # NOTE(review): a second function named ``test_save`` appears later in this
    # file; if both live in the same module the later definition replaces this
    # one at import time -- confirm and rename one of them.
    update_api = sagemaker_boto_client.update_trial_component
    update_api.return_value = {}

    component = trial_component.TrialComponent(
        sagemaker_boto_client,
        trial_component_name='foo',
        display_name='bar',
    )
    component.save()

    update_api.assert_called_with(TrialComponentName='foo', DisplayName='bar')
def test_add_trial_component(sagemaker_boto_client):
    """Check ``add_trial_component`` with a name string, a TrialComponent, and a Tracker.

    In each case the trial named ``"bar"`` is expected to call
    ``associate_trial_component`` with the matching component name.
    """
    parent_trial = trial.Trial(sagemaker_boto_client)
    parent_trial.trial_name = "bar"
    associate = sagemaker_boto_client.associate_trial_component

    # Case 1: plain component name.
    parent_trial.add_trial_component("foo")
    associate.assert_called_with(TrialName="bar", TrialComponentName="foo")

    # Case 2: TrialComponent instance.
    component = trial_component.TrialComponent(
        trial_component_name="tc-foo",
        sagemaker_boto_client=sagemaker_boto_client,
    )
    parent_trial.add_trial_component(component)
    associate.assert_called_with(
        TrialName="bar",
        TrialComponentName=component.trial_component_name,
    )

    # Case 3: Tracker wrapping a TrialComponent; the other Tracker arguments are mocks.
    tracked_component = trial_component.TrialComponent(
        trial_component_name="tc-foo2",
        sagemaker_boto_client=sagemaker_boto_client,
    )
    wrapping_tracker = tracker.Tracker(
        tracked_component,
        unittest.mock.Mock(),
        unittest.mock.Mock(),
    )
    parent_trial.add_trial_component(wrapping_tracker)
    associate.assert_called_with(
        TrialName="bar",
        TrialComponentName=tracked_component.trial_component_name,
    )
def test_load_in_sagemaker_processing_job(mocked_tce, sagemaker_boto_client):
    """Check ``Tracker.load`` when the environment reports a SageMaker processing job.

    The loaded tracker is expected to be flagged as running in a SageMaker job,
    to have no metrics writer, and to expose the trial component returned by the
    environment object.
    """
    # NOTE(review): ``mocked_tce`` is presumably injected by a patch decorator
    # or fixture outside this view -- confirm.
    expected_component = trial_component.TrialComponent(
        trial_component_name="foo-bar",
        sagemaker_boto_client=sagemaker_boto_client,
    )

    # Fake environment object handed back by the mocked ``load``.
    environment = unittest.mock.Mock(
        source_arn="arn:1234",
        environment_type=_environment.EnvironmentType.SageMakerProcessingJob,
    )
    environment.get_trial_component.return_value = expected_component
    mocked_tce.load.return_value = environment

    loaded = tracker.Tracker.load(sagemaker_boto_client=sagemaker_boto_client)

    assert loaded._in_sagemaker_job
    assert loaded._metrics_writer is None
    assert loaded.trial_component == expected_component
def test_save(sagemaker_boto_client):
    """Check that ``save()`` forwards the ``*_to_remove`` lists to ``update_trial_component``.

    Besides the name and display name, the update call is expected to carry
    ``ParametersToRemove``, ``InputArtifactsToRemove`` and ``OutputArtifactsToRemove``.
    """
    # NOTE(review): an earlier function in this file is also named
    # ``test_save``; if both live in the same module this definition replaces
    # the earlier one at import time -- confirm and rename one of them.
    update_api = sagemaker_boto_client.update_trial_component
    update_api.return_value = {}

    component = trial_component.TrialComponent(
        sagemaker_boto_client,
        trial_component_name="foo",
        display_name="bar",
        parameters_to_remove=["E"],
        input_artifacts_to_remove=["F"],
        output_artifacts_to_remove=["G"],
    )
    component.save()

    expected_kwargs = {
        "TrialComponentName": "foo",
        "DisplayName": "bar",
        "ParametersToRemove": ["E"],
        "InputArtifactsToRemove": ["F"],
        "OutputArtifactsToRemove": ["G"],
    }
    update_api.assert_called_with(**expected_kwargs)
def test_create_with_trial_components(sagemaker_boto_client):
    """Check that ``Trial.create`` with ``trial_components`` creates the trial and associates the component.

    ``create_trial`` is stubbed to return an ARN and trial name; the test then
    expects both the ``create_trial`` call and an ``associate_trial_component``
    call for the supplied component.
    """
    create_response = {"Arn": "arn:aws:1234", "TrialName": "name-value"}
    sagemaker_boto_client.create_trial.return_value = create_response

    component = trial_component.TrialComponent(
        trial_component_name="tc-foo",
        sagemaker_boto_client=sagemaker_boto_client,
    )

    created = trial.Trial.create(
        trial_name="name-value",
        experiment_name="experiment-name-value",
        trial_components=[component],
        sagemaker_boto_client=sagemaker_boto_client,
    )

    assert created.trial_name == "name-value"
    sagemaker_boto_client.create_trial.assert_called_with(
        TrialName="name-value",
        ExperimentName="experiment-name-value",
    )
    sagemaker_boto_client.associate_trial_component.assert_called_with(
        TrialName="name-value",
        TrialComponentName=component.trial_component_name,
    )
def test_boto_ignore(sagemaker_boto_client):
    """Check that ``_boto_ignore()`` returns ``["ResponseMetadata", "CreatedBy"]``.

    Args:
        sagemaker_boto_client: The boto client test double, requested the same
            way as in the other tests of this file.
    """
    # Take the client as an argument: as a bare name inside the body it would
    # resolve to whatever module-level object carries that name (or raise
    # NameError) rather than to the value supplied to the test.
    component = trial_component.TrialComponent(
        sagemaker_boto_client,
        trial_component_name="foo",
        display_name="bar",
    )
    assert component._boto_ignore() == ["ResponseMetadata", "CreatedBy"]
def trial_component_obj(sagemaker_boto_client):
    """Return a ``TrialComponent`` constructed with only the given boto client.

    NOTE(review): presumably registered as a pytest fixture by a decorator
    outside this view -- confirm.
    """
    component = trial_component.TrialComponent(sagemaker_boto_client)
    return component