def test_remove_trial_component_from_tracker(sagemaker_boto_client):
    """Removing a Tracker from a Trial disassociates the Tracker's trial component.

    A Tracker wrapping a TrialComponent named "tc-foo" is passed to
    ``Trial.remove_trial_component``. The test then checks that the boto client's
    ``disassociate_trial_component`` was (last) called with the trial's name and
    the wrapped component's name.
    """
    owning_trial = trial.Trial(sagemaker_boto_client)
    owning_trial.trial_name = "bar"

    wrapped_component = trial_component.TrialComponent(
        trial_component_name="tc-foo",
        sagemaker_boto_client=sagemaker_boto_client,
    )
    # The two extra positional args are opaque collaborators; plain Mocks suffice here.
    component_tracker = tracker.Tracker(
        wrapped_component,
        unittest.mock.Mock(),
        unittest.mock.Mock(),
    )

    owning_trial.remove_trial_component(component_tracker)

    sagemaker_boto_client.disassociate_trial_component.assert_called_with(
        TrialName="bar",
        TrialComponentName="tc-foo",
    )
def test_add_trial_component(sagemaker_boto_client):
    """``Trial.add_trial_component`` associates components given in three forms.

    The accepted argument forms exercised are: a bare component-name string,
    a ``TrialComponent`` object, and a ``Tracker`` wrapping a ``TrialComponent``.
    After each call, the boto client's ``associate_trial_component`` must have
    been (most recently) called with the trial's name and the component's name.
    """

    def _expect_associated(component_name):
        # Look the mock attribute up on every check, exactly as written inline before.
        sagemaker_boto_client.associate_trial_component.assert_called_with(
            TrialName="bar",
            TrialComponentName=component_name,
        )

    target_trial = trial.Trial(sagemaker_boto_client)
    target_trial.trial_name = "bar"

    # Form 1: the component is identified only by its name.
    target_trial.add_trial_component("foo")
    _expect_associated("foo")

    # Form 2: a TrialComponent instance.
    first_component = trial_component.TrialComponent(
        trial_component_name="tc-foo",
        sagemaker_boto_client=sagemaker_boto_client,
    )
    target_trial.add_trial_component(first_component)
    _expect_associated(first_component.trial_component_name)

    # Form 3: a Tracker wrapping a TrialComponent.
    second_component = trial_component.TrialComponent(
        trial_component_name="tc-foo2",
        sagemaker_boto_client=sagemaker_boto_client,
    )
    wrapping_tracker = tracker.Tracker(
        second_component,
        unittest.mock.Mock(),
        unittest.mock.Mock(),
    )
    target_trial.add_trial_component(wrapping_tracker)
    _expect_associated(second_component.trial_component_name)
def under_test(trial_component_obj):
    """Return a ``tracker.Tracker`` around *trial_component_obj*.

    The Tracker's two remaining positional arguments are filled with fresh
    ``unittest.mock.Mock()`` placeholders, so the result has no real
    collaborators behind it.
    """
    placeholder_a = unittest.mock.Mock()
    placeholder_b = unittest.mock.Mock()
    built = tracker.Tracker(trial_component_obj, placeholder_a, placeholder_b)
    return built