def test_whenLogConfigParamsAConfigDictWithSequence_givenAMLFlowCallback_thenLogParams(self):
    """Logging a config dict containing a sequence produces the expected client calls.

    The expected calls are derived by ``self._populate_calls_from_dict`` from the
    very same settings object that is handed to ``log_config_params``.
    """
    settings = self.settings_in_dictconfig_with_sequence
    with patch("poutyne.framework.mlflow_logger.MlflowClient") as client_cls:
        client = client_cls.return_value
        client.create_experiment = self.experiment_mock
        client.create_run = self.run_mock

        logger = MLFlowLogger(self.a_experiment_name)
        logger.log_config_params(settings)

        expected_calls = self._populate_calls_from_dict(settings)
        client_cls.assert_has_calls(expected_calls)
def test_whenOnTestFailure_givenAMLFlowCallback_thenHasFailureTerminatedStatus(self):  # pylint: disable=protected-access
    """Calling ``_status_handling`` right after construction terminates the run as "FAILED".

    No train/test callbacks are invoked before the status is handled.
    """
    with patch("poutyne.framework.mlflow_logger.MlflowClient") as client_cls:
        client = client_cls.return_value
        client.create_experiment = self.experiment_mock
        client.create_run = self.run_mock

        logger = MLFlowLogger(self.a_experiment_name)
        logger._status_handling()

        expected_terminate = call().set_terminated(self.a_run_id, status="FAILED")
        client_cls.assert_has_calls([expected_terminate])
def test_whenLogConfigParamsASimpleDict_givenAMLFlowCallback_thenLogParams(self):
    """Logging a flat dict calls ``log_param`` on the client for every key/value pair.

    Each expected call carries the run id created for the logger, plus the
    key and value taken from ``self.settings_in_dict``.
    """
    with patch("poutyne.framework.mlflow_logger.MlflowClient") as ml_flow_client_patch:
        ml_flow_client_patch.return_value.create_experiment = self.experiment_mock
        ml_flow_client_patch.return_value.create_run = self.run_mock

        mlflow_logger = MLFlowLogger(self.a_experiment_name)
        mlflow_logger.log_config_params(self.settings_in_dict)

        # assert_has_calls requires these to appear consecutively and in dict
        # iteration order among the recorded client calls.
        ml_flow_client_calls = [
            call().log_param(run_id=self.a_run_id, key=key, value=value)
            for key, value in self.settings_in_dict.items()
        ]
        ml_flow_client_patch.assert_has_calls(ml_flow_client_calls)
def test_whenLogMetric_givenAMLFlowCallback_thenLogMetric(self):
    """Every ``log_metric`` call is forwarded to the client with the run id and ``step=None``."""
    with patch("poutyne.framework.mlflow_logger.MlflowClient") as client_cls:
        client = client_cls.return_value
        client.create_experiment = self.experiment_mock
        client.create_run = self.run_mock
        logger = MLFlowLogger(self.a_experiment_name)

        for metric_name, metric_value in self.a_log.items():
            logger.log_metric(metric_name, metric_value)

        expected_calls = [
            call().log_metric(run_id=self.a_run_id, key=metric_name, value=metric_value, step=None)
            for metric_name, metric_value in self.a_log.items()
        ]
        client_cls.assert_has_calls(expected_calls)
def test_whenOnTestSuccess_givenAMLFlowCallback_thenHasSuccessTerminatedStatus(self):
    """Finishing training and then a test phase terminates the run as "FINISHED"."""
    with patch("poutyne.framework.mlflow_logger.MlflowClient") as client_cls:
        client = client_cls.return_value
        client.create_experiment = self.experiment_mock
        client.create_run = self.run_mock

        logger = MLFlowLogger(self.a_experiment_name)
        logger.set_params({"epochs": self.epochs})
        logger.on_train_end(self.a_log)
        # The status is changed when testing starts, so the test phase must be
        # opened explicitly before it is closed.
        logger.on_test_begin({})
        logger.on_test_end(self.a_log)

        expected_terminate = call().set_terminated(self.a_run_id, status="FINISHED")
        client_cls.assert_has_calls([expected_terminate])
def test_whenNewExperiment_givenAMLFlowInstantiation_thenCreateNewExperiment(self):
    """Instantiating the logger calls ``create_experiment`` with the name and ``self.none_tracking_uri``."""
    with patch("poutyne.framework.mlflow_logger.MlflowClient") as client_cls:
        client_cls.return_value.create_experiment = self.experiment_mock

        MLFlowLogger(self.a_experiment_name)

        expected_create = call().create_experiment(self.a_experiment_name, self.none_tracking_uri)
        client_cls.assert_has_calls([expected_create])
def test_whenGitRepo_givenAMLFlowInstantiation_thenLogGitCommit(self, get_git_commit_patch):
    """Instantiation looks up the git commit of the working directory and tags the run with it.

    Args:
        get_git_commit_patch: mock expected to be called with
            ``self.the_working_directory``. NOTE(review): presumably injected by
            a ``patch`` decorator that is not visible here — confirm.
    """
    with patch("poutyne.framework.mlflow_logger.MlflowClient") as client_cls:
        client_cls.return_value.create_run = self.run_mock

        MLFlowLogger(self.a_experiment_name)

        get_git_commit_patch.assert_has_calls([call(self.the_working_directory)])
        expected_tag = call().set_tag(self.a_run_id, mlflow_default_git_commit_tag, a_git_commit)
        client_cls.assert_has_calls([expected_tag])
def test_whenExperimentAlreadyCreated_givenAMLFlowInstantiation_thenGetExperiment(self):
    """When ``create_experiment`` raises ``MlflowException``, the experiment is fetched by name instead."""
    with patch("poutyne.framework.mlflow_logger.MlflowClient") as client_cls:
        client = client_cls.return_value
        creation_error = MlflowException(self.a_exception_message)
        client.create_experiment = MagicMock(side_effect=creation_error)
        client.get_experiment_by_name = MagicMock(return_value=self.experiment_mock)

        MLFlowLogger(self.a_experiment_name)

        # The failed creation attempt must be followed directly by the lookup.
        expected_calls = [
            call().create_experiment(self.a_experiment_name, self.none_tracking_uri),
            call().get_experiment_by_name(self.a_experiment_name),
        ]
        client_cls.assert_has_calls(expected_calls)
def test_whenOnTrainEndSuccess_givenAMLFlowCallback_thenLogLastEpochNumber(self):
    """``on_train_end`` logs a ``last-epoch`` metric equal to the epoch count given via ``set_params``."""
    with patch("poutyne.framework.mlflow_logger.MlflowClient") as client_cls:
        client = client_cls.return_value
        client.create_experiment = self.experiment_mock
        client.create_run = self.run_mock

        logger = MLFlowLogger(self.a_experiment_name)
        logger.set_params({"epochs": self.epochs})
        logger.on_train_end(self.a_log)

        expected_metric = call().log_metric(
            run_id=self.a_run_id,
            key='last-epoch',
            value=self.epochs,
            step=None,
        )
        client_cls.assert_has_calls([expected_metric])
def test_whenCorrectSettings_givenAMLFlowInstantiation_thenMLflowClientIsProperlySet(self):
    """Instantiation creates the experiment and a run, then exposes both ids on the logger."""
    with patch("poutyne.framework.mlflow_logger.MlflowClient") as client_cls:
        client = client_cls.return_value
        client.create_experiment = self.experiment_mock
        client.create_run = self.run_mock

        logger = MLFlowLogger(self.a_experiment_name)

        client_cls.assert_has_calls(
            [
                call().create_experiment(self.a_experiment_name, self.none_tracking_uri),
                call().create_run(experiment_id=self.a_experiment_id),
            ]
        )
        self.assertEqual(self.a_experiment_id, logger.experiment_id)
        self.assertEqual(self.a_run_id, logger.run_id)