def test_set_active_run(context, run_id):
    """Check which runs ``MlFlow._set_active_run`` starts, and in what order.

    ``_setup`` is patched out while the resource is constructed. ``_start_run``
    is then patched so that the calls made by the single explicit
    ``_set_active_run`` call can be inspected.

    - Nested case (``parent_run_id is not None``): ``_start_run`` must be
      called exactly twice. The parent run comes first, then the requested
      run with ``nested=True``.
    - Non-nested case: ``_start_run`` must be called once with ``nested=False``.
    """
    # Local import: only needed to build the expected call list below.
    from unittest.mock import call

    with patch.object(MlFlow, "_setup"):
        # Given: a context passed into the __init__ for MlFlow
        mlf = MlFlow(context)

    with patch.object(MlFlow, "_start_run") as mock_start_run:
        # When: _set_active_run is called
        mlf._set_active_run(run_id=run_id)  # pylint: disable=protected-access

        # And: the run is nested
        if mlf.parent_run_id is not None:
            # Then:
            # - the parent run is started first
            # - the requested run is then started via _start_run with nested=True
            # - _start_run is called exactly twice
            # Comparing the whole call list pins the calls, their count AND
            # their order. assert_any_call is order-insensitive, so it cannot
            # catch a child run being started before its parent.
            assert mock_start_run.call_args_list == [
                call(run_id=mlf.parent_run_id, run_name=mlf.run_name),
                call(run_id=run_id, run_name=mlf.run_name, nested=True),
            ]
        # And: the run is not nested
        else:
            # Then: the requested run is started exactly once with nested=False
            mock_start_run.assert_called_once_with(
                run_id=run_id, run_name=mlf.run_name, nested=False
            )
def test_set_active_run_parent_zero(child_context):
    """Check that a ``parent_run_id`` of ``0`` is treated as a real parent run.

    ``0`` is falsy. This test therefore catches an implementation that checks
    ``parent_run_id`` for truthiness instead of comparing it against ``None``.
    In that case the parent run would be skipped and the child would not be
    started as nested.
    """
    # Local import: only needed to build the expected call list below.
    from unittest.mock import call

    # Given: a parent_run_id of zero
    # NOTE(review): this mutates the fixture's resource_config in place; fine
    # if child_context is function-scoped — confirm it is not shared.
    child_context.resource_config["parent_run_id"] = 0

    # And: an initialization of the mlflow object. _setup is patched out, just
    # as test_set_active_run does, so the test does not depend on whatever
    # _setup does. parent_run_id and run_name are read after construction.
    with patch.object(MlFlow, "_setup"):
        mlf = MlFlow(child_context)

    with patch.object(MlFlow, "_start_run") as mock_start_run:
        # When: _set_active_run is called with a run_id
        mlf._set_active_run(run_id="what-is-an-edge-case")  # pylint: disable=protected-access

        # Then: _start_run is called exactly twice, in this order:
        # - first for the parent run (parent_run_id == 0)
        # - then for the requested run, with nested=True
        # A full call-list comparison also pins the ordering, which separate
        # assert_any_call checks would not.
        assert mock_start_run.call_args_list == [
            call(run_id=mlf.parent_run_id, run_name=mlf.run_name),
            call(
                run_id="what-is-an-edge-case",
                run_name=mlf.run_name,
                nested=True,
            ),
        ]