def test_whenLogConfigParamsAConfigDictWithSequence_givenAMLFlowCallback_thenLogParams(self):
    """Check that ``log_config_params`` logs a config containing a sequence.

    ``MlflowClient`` is patched inside ``poutyne.framework.mlflow_logger``, and
    the client instance it returns gets the fixture experiment/run mocks. The
    expected calls come from ``self._populate_calls_from_dict`` applied to the
    same settings fixture that is logged.
    """
    with patch("poutyne.framework.mlflow_logger.MlflowClient") as client_cls_mock:
        # ``return_value`` is the object MLFlowLogger gets when it calls MlflowClient().
        client_instance = client_cls_mock.return_value
        client_instance.create_experiment = self.experiment_mock
        client_instance.create_run = self.run_mock

        settings = self.settings_in_dictconfig_with_sequence
        logger_under_test = MLFlowLogger(self.a_experiment_name)
        logger_under_test.log_config_params(settings)

        expected_calls = self._populate_calls_from_dict(settings)
        client_cls_mock.assert_has_calls(expected_calls)
def test_whenLogConfigParamsASimpleDict_givenAMLFlowCallback_thenLogParams(self):
    """Check that ``log_config_params`` logs each entry of a flat dict.

    ``MlflowClient`` is patched inside ``poutyne.framework.mlflow_logger``, and
    the client instance it returns gets the fixture experiment/run mocks. For
    every key/value pair of ``self.settings_in_dict``, the test expects one
    ``log_param(run_id=self.a_run_id, key=key, value=value)`` call on that
    client instance.
    """
    with patch("poutyne.framework.mlflow_logger.MlflowClient") as ml_flow_client_patch:
        ml_flow_client_patch.return_value.create_experiment = self.experiment_mock
        ml_flow_client_patch.return_value.create_run = self.run_mock

        mlflow_logger = MLFlowLogger(self.a_experiment_name)
        mlflow_logger.log_config_params(self.settings_in_dict)

        # ``call()`` stands for the MlflowClient() constructor call; ``.log_param``
        # is then the method call on the client instance it returned.
        # NOTE(review): self.a_run_id is presumably the id exposed by self.run_mock — confirm in setUp.
        ml_flow_client_calls = [
            call().log_param(run_id=self.a_run_id, key=key, value=value)
            for key, value in self.settings_in_dict.items()
        ]
        ml_flow_client_patch.assert_has_calls(ml_flow_client_calls)